-
Notifications
You must be signed in to change notification settings - Fork 27
/
generate.go
233 lines (220 loc) · 7.57 KB
/
generate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
package pggen
import (
"context"
"errors"
"fmt"
"github.com/jackc/pgx/v4"
"github.com/jschaf/pggen/internal/ast"
"github.com/jschaf/pggen/internal/codegen"
"github.com/jschaf/pggen/internal/codegen/golang"
"github.com/jschaf/pggen/internal/errs"
"github.com/jschaf/pggen/internal/parser"
"github.com/jschaf/pggen/internal/pgdocker"
"github.com/jschaf/pggen/internal/pginfer"
gotok "go/token"
"log/slog"
"os"
"path/filepath"
"time"
)
// Lang names a language that pggen can generate code for.
type Lang string

// LangGo selects Go as the code generation target.
const LangGo Lang = "go"
// GenerateOptions are the unparsed options that control the generated code.
type GenerateOptions struct {
	// What language to generate code in. Must be non-empty; Generate rejects
	// languages it does not support.
	Language Lang
	// The connection string to the running Postgres database to use to get type
	// information for each query in QueryFiles. If empty, Generate starts a
	// Docker Postgres container and connects to that instead.
	//
	// Must be parseable by pgconn.ParseConfig, like:
	//
	//   # Example DSN
	//   user=jack password=secret host=pg.example.com port=5432 dbname=foo_db sslmode=verify-ca
	//
	//   # Example URL
	//   postgres://jack:secret@pg.example.com:5432/foo_db?sslmode=verify-ca
	ConnString string
	// Generate code for each of the SQL query file paths. At least one is
	// required.
	QueryFiles []string
	// Schema files to run on Postgres init. Can be *.sql, *.sql.gz, or
	// executable *.sh files. When ConnString is set, only *.sql files are
	// supported.
	SchemaFiles []string
	// The name of the Go package for the file. If empty, defaults to the
	// directory name.
	GoPackage string
	// Directory to write generated files. Writes one file for each query file.
	// If more than one query file, also writes querier.go. Required.
	OutputDir string
	// A map of lowercase acronyms to the upper case equivalent, like:
	// "api" => "API", or "apis" => "APIs". Generate uses "id" => "ID" unless
	// this map already has an "id" key.
	Acronyms map[string]string
	// A map from a Postgres type name to a fully qualified Go type.
	TypeOverrides map[string]string
	// What log level to log at.
	LogLevel slog.Level
	// How many params to inline when calling querier methods.
	// Set to 0 to always create a struct for params.
	InlineParamCount int
}
// Generate generates language specific code to safely wrap each SQL
// ast.SourceQuery in opts.QueryFiles.
//
// It validates opts, connects to Postgres (starting a Docker container when
// opts.ConnString is empty), parses every query file and infers query types
// against that database, then emits code for opts.Language into
// opts.OutputDir. All option checks that don't need Postgres, including the
// language check, run before connecting. A bad option therefore never pays for
// a connection or a container start.
//
// Generate does not modify opts.Acronyms. The default "id" => "ID" acronym is
// added to a private copy.
//
// Generate must only be called once per output directory.
func Generate(opts GenerateOptions) (mErr error) {
	// Preconditions.
	if opts.Language == "" {
		return fmt.Errorf("generate language must be set; got empty string")
	}
	if len(opts.QueryFiles) == 0 {
		return fmt.Errorf("got 0 query files, at least 1 must be set")
	}
	if opts.OutputDir == "" {
		return fmt.Errorf("output dir must be set")
	}

	// Copy the acronyms: maps are references, so writing the default into
	// opts.Acronyms would mutate the caller's map.
	acronyms := make(map[string]string, len(opts.Acronyms)+1)
	for k, v := range opts.Acronyms {
		acronyms[k] = v
	}
	if _, ok := acronyms["id"]; !ok {
		acronyms["id"] = "ID"
	}

	// Choose the code emitter up front so an unsupported language fails
	// before we connect to (or start) Postgres.
	var emit func([]codegen.QueryFile) error
	switch opts.Language {
	case LangGo:
		goOpts := golang.GenerateOptions{
			GoPkg:            opts.GoPackage,
			OutputDir:        opts.OutputDir,
			Acronyms:         acronyms,
			TypeOverrides:    opts.TypeOverrides,
			InlineParamCount: opts.InlineParamCount,
		}
		emit = func(files []codegen.QueryFile) error {
			if err := golang.Generate(goOpts, files); err != nil {
				return fmt.Errorf("generate go code: %w", err)
			}
			return nil
		}
	default:
		return fmt.Errorf("unsupported output language %q", opts.Language)
	}

	// Postgres connection.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()
	pgConn, errEnricher, cleanup, err := connectPostgres(ctx, opts)
	if err != nil {
		return fmt.Errorf("connect postgres: %w", err)
	}
	// Runs before cancel (defers are LIFO), so cleanup still has a live ctx
	// unless the timeout already expired.
	defer errs.Capture(&mErr, cleanup, "close postgres connection")

	// Parse queries and infer their types against the database.
	inferrer := pginfer.NewInferrer(pgConn)
	queryFiles, err := parseQueryFiles(opts.QueryFiles, inferrer)
	if err != nil {
		return errEnricher(err)
	}

	// Codegen.
	return emit(queryFiles)
}
// connectPostgres connects to Postgres using opts.ConnString if given. When
// opts.ConnString is empty, it instead starts a Docker Postgres container
// initialized with opts.SchemaFiles and connects to that.
//
// On success it returns the connection, an error enricher, and a cleanup func:
//   - The enricher wraps a non-nil error with the container's logs in the
//     Docker case, and returns the error unchanged otherwise.
//   - Cleanup closes the connection and, in the Docker case, then stops the
//     container. The caller must call it once when done with the connection.
//
// With an existing database, only *.sql schema files are supported. Each one
// is read and executed on the connection, in order.
//
// On failure the returned connection and funcs are nil. Anything acquired
// before the failure (the connection or the container) has already been
// released, and any error from releasing it is joined onto the returned error.
func connectPostgres(ctx context.Context, opts GenerateOptions) (*pgx.Conn, func(error) error, func() error, error) {
	// Create connection by starting dockerized Postgres.
	if opts.ConnString == "" {
		client, err := pgdocker.Start(ctx, opts.SchemaFiles)
		if err != nil {
			return nil, nil, nil, fmt.Errorf("start dockerized postgres: %w", err)
		}
		stopDocker := func() error { return client.Stop(ctx) }
		// abort stops the container when a later setup step fails, so a failed
		// connection attempt doesn't leave the container running.
		abort := func(err error) (*pgx.Conn, func(error) error, func() error, error) {
			return nil, nil, nil, errors.Join(err, stopDocker())
		}
		connStr, err := client.ConnString()
		if err != nil {
			return abort(fmt.Errorf("get dockerized postgres conn string: %w", err))
		}
		pgConn, err := pgx.Connect(ctx, connStr)
		if err != nil {
			return abort(fmt.Errorf("connect to pggen dockerized postgres database: %w", err))
		}
		errEnricher := func(e error) error {
			if e == nil {
				return e
			}
			logs, err := client.GetContainerLogs()
			if err != nil {
				return errors.Join(e, err)
			}
			return fmt.Errorf("Container logs for Postgres container:\n\n%s\n\n%w", logs, e)
		}
		cleanup := func() error {
			// Close the connection before stopping the server it talks to.
			// errors.Join evaluates both calls, so the container is stopped
			// even if closing the connection fails.
			return errors.Join(pgConn.Close(ctx), stopDocker())
		}
		return pgConn, errEnricher, cleanup, nil
	}

	// Use existing Postgres.
	pgConn, err := pgx.Connect(ctx, opts.ConnString)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("connect to pggen postgres database: %w", err)
	}
	closeConn := func() error { return pgConn.Close(ctx) }
	// abort closes the connection when a schema file can't be loaded, so the
	// connection doesn't leak on the error path.
	abort := func(err error) (*pgx.Conn, func(error) error, func() error, error) {
		return nil, nil, nil, errors.Join(err, closeConn())
	}
	// Run SQL init scripts. pgdocker runs these in the other case by copying
	// the files into the entrypoint folder. Emulate the behavior for a subset of
	// supported files.
	for _, script := range opts.SchemaFiles {
		if filepath.Ext(script) != ".sql" {
			return abort(fmt.Errorf("cannot run non-sql schema file on Postgres "+
				"(*.sh and *.sql.gz files only supported without --postgres-connection): %s", script))
		}
		bs, err := os.ReadFile(script)
		if err != nil {
			return abort(fmt.Errorf("read schema file: %w", err))
		}
		if _, err := pgConn.Exec(ctx, string(bs)); err != nil {
			return abort(fmt.Errorf("load schema file into Postgres: %w", err))
		}
	}
	nopErrEnricher := func(e error) error { return e }
	return pgConn, nopErrEnricher, closeConn, nil
}
// parseQueryFiles parses each path in queryFiles, in order, into a
// codegen.QueryFile with types inferred by inferrer. Each path is resolved to
// an absolute path before it is handed to parseQueries. It stops at the first
// failure and wraps the error with the path as the caller gave it.
func parseQueryFiles(queryFiles []string, inferrer *pginfer.Inferrer) ([]codegen.QueryFile, error) {
	parsed := make([]codegen.QueryFile, 0, len(queryFiles))
	for _, path := range queryFiles {
		absPath, err := filepath.Abs(path)
		if err != nil {
			return nil, fmt.Errorf("resolve absolute path for %q: %w", path, err)
		}
		qf, err := parseQueries(absPath, inferrer)
		if err != nil {
			return nil, fmt.Errorf("parse template query file %q: %w", path, err)
		}
		parsed = append(parsed, qf)
	}
	return parsed, nil
}
// parseQueries parses the query file at srcPath and infers the Postgres types
// of every query it contains.
//
// Work happens in two passes. The first pass validates the whole file before
// any inference. A *ast.BadQuery, an unexpected query node type, or a repeated
// query name fails the file without calling inferrer. The error message
// implies the parser is expected to return an error rather than emit a
// BadQuery. The second pass infers types one query at a time, in file order,
// and stops at the first failure.
func parseQueries(srcPath string, inferrer *pginfer.Inferrer) (codegen.QueryFile, error) {
	fset := gotok.NewFileSet()
	astFile, err := parser.ParseFile(fset, srcPath, nil, 0)
	if err != nil {
		return codegen.QueryFile{}, fmt.Errorf("parse query file %q: %w", srcPath, err)
	}

	// Pass 1: reject bad nodes and duplicate names before touching Postgres.
	sources := make([]*ast.SourceQuery, 0, len(astFile.Queries))
	names := make(map[string]struct{}, len(astFile.Queries))
	for _, node := range astFile.Queries {
		switch q := node.(type) {
		case *ast.SourceQuery:
			if _, dup := names[q.Name]; dup {
				return codegen.QueryFile{}, fmt.Errorf("duplicate query name %s", q.Name)
			}
			names[q.Name] = struct{}{}
			sources = append(sources, q)
		case *ast.BadQuery:
			return codegen.QueryFile{}, errors.New("parsed bad query instead of erroring")
		default:
			return codegen.QueryFile{}, fmt.Errorf("unhandled query ast type: %T", node)
		}
	}

	// Pass 2: infer types for every validated query.
	typed := make([]pginfer.TypedQuery, 0, len(sources))
	for _, src := range sources {
		tq, err := inferrer.InferTypes(src)
		if err != nil {
			return codegen.QueryFile{}, fmt.Errorf("infer typed named query %s: %w", src.Name, err)
		}
		typed = append(typed, tq)
	}

	return codegen.QueryFile{SourcePath: srcPath, Queries: typed}, nil
}