diff --git a/Makefile b/Makefile index d6111e1c6c..d393ad90b8 100644 --- a/Makefile +++ b/Makefile @@ -9,13 +9,16 @@ install: test: go test ./... +vet: + go vet ./... + test-examples: go test --tags=examples ./... build-endtoend: cd ./internal/endtoend/testdata && go build ./... -test-ci: test-examples build-endtoend +test-ci: test-examples build-endtoend vet regen: sqlc-dev sqlc-gen-json go run ./scripts/regenerate/ diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index baa4b3d842..204da3212d 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -45,20 +45,26 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int rootCmd.SetOut(stdout) rootCmd.SetErr(stderr) - ctx, cleanup, err := tracer.Start(context.Background()) - if err != nil { - fmt.Printf("failed to start trace: %v\n", err) - return 1 + ctx := context.Background() + if debug.Debug.Trace != "" { + tracectx, cleanup, err := tracer.Start(ctx) + if err != nil { + fmt.Printf("failed to start trace: %v\n", err) + return 1 + } + ctx = tracectx + defer cleanup() } - defer cleanup() - if err := rootCmd.ExecuteContext(ctx); err == nil { - return 0 - } - if exitError, ok := err.(*exec.ExitError); ok { - return exitError.ExitCode() + if err := rootCmd.ExecuteContext(ctx); err != nil { + fmt.Fprintf(stderr, "%v\n", err) + if exitError, ok := err.(*exec.ExitError); ok { + return exitError.ExitCode() + } else { + return 1 + } } - return 1 + return 0 } var version string @@ -66,15 +72,14 @@ var version string var versionCmd = &cobra.Command{ Use: "version", Short: "Print the sqlc version number", - Run: func(cmd *cobra.Command, args []string) { - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "version").End() - } + RunE: func(cmd *cobra.Command, args []string) error { + defer trace.StartRegion(cmd.Context(), "version").End() if version == "" { fmt.Printf("%s\n", info.Version) } else { fmt.Printf("%s\n", version) } + return nil }, } @@ -82,9 +87,7 @@ var initCmd = &cobra.Command{ Use: "init", Short: "Create an empty sqlc.yaml settings file", RunE: func(cmd *cobra.Command, args []string) error { - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "init").End() - } + defer trace.StartRegion(cmd.Context(), "init").End() file := "sqlc.yaml" if f := cmd.Flag("file"); f != nil && f.Changed { file = f.Value.String() @@ -153,26 +156,23 @@ func getConfigPath(stderr io.Writer, f *pflag.Flag) (string, string) { var genCmd = &cobra.Command{ Use: "generate", Short: "Generate Go code from SQL", - Run: func(cmd *cobra.Command, args []string) { - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "generate").End() - } + RunE: func(cmd *cobra.Command, args []string) error { + defer trace.StartRegion(cmd.Context(), "generate").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) output, err := Generate(cmd.Context(), ParseEnv(cmd), dir, name, stderr) if err != nil { - os.Exit(1) - } - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "writefiles").End() + return err } + defer trace.StartRegion(cmd.Context(), "writefiles").End() for filename, source := range output { os.MkdirAll(filepath.Dir(filename), 0755) if err := os.WriteFile(filename, []byte(source), 0644); err != nil { fmt.Fprintf(stderr, "%s: %s\n", filename, err) - os.Exit(1) + return err } } + return nil }, } @@ -194,9 +194,7 @@ var checkCmd = &cobra.Command{ Use: "compile", Short: "Statically check SQL for syntax and type errors", RunE: func(cmd *cobra.Command, args []string) error { - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "compile").End() - } + defer trace.StartRegion(cmd.Context(), "compile").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) if _, err := Generate(cmd.Context(), ParseEnv(cmd), dir, name, stderr); err != nil { @@ -239,9 +237,7 @@ var diffCmd = &cobra.Command{ Use: "diff", Short: "Compare the generated files to the existing files", RunE: func(cmd *cobra.Command, args []string) error { - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "diff").End() - } + defer trace.StartRegion(cmd.Context(), "diff").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) if err := Diff(cmd.Context(), ParseEnv(cmd), dir, name, stderr); err != nil { diff --git a/internal/cmd/diff.go b/internal/cmd/diff.go index dbab4c6ed8..aa1ddc8788 100644 --- a/internal/cmd/diff.go +++ b/internal/cmd/diff.go @@ -11,7 +11,6 @@ import ( "strings" "github.com/cubicdaiya/gonp" - "github.com/kyleconroy/sqlc/internal/debug" ) func Diff(ctx context.Context, e Env, dir, name string, stderr io.Writer) error { @@ -19,9 +18,7 @@ func Diff(ctx context.Context, e Env, dir, name string, stderr io.Writer) error if err != nil { return err } - if debug.Traced { - defer trace.StartRegion(ctx, "checkfiles").End() - } + defer trace.StartRegion(ctx, "checkfiles").End() var errored bool keys := make([]string, 0, len(output)) diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 602974b9cc..450cb9a60e 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -193,17 +193,12 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer name = sql.Plugin.Plugin } - var packageRegion *trace.Region - if debug.Traced { - packageRegion = trace.StartRegion(ctx, "package") - trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang) - } + packageRegion := trace.StartRegion(ctx, "package") + trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang) result, failed := parse(ctx, name, dir, sql.SQL, combo, parseOpts, stderr) if failed { - if packageRegion != nil { - packageRegion.End() - } + packageRegion.End() errored = true break } @@ -213,9 +208,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer fmt.Fprintf(stderr, "# package %s\n", name) fmt.Fprintf(stderr, "error generating code: %s\n", err) errored = true - if packageRegion != nil { - packageRegion.End() - } + packageRegion.End() continue } @@ -227,9 +220,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer filename := filepath.Join(dir, out, n) output[filename] = source } - if packageRegion != nil { - packageRegion.End() - } + packageRegion.End() } if errored { @@ -239,9 +230,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer } func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { - if debug.Traced { - defer trace.StartRegion(ctx, "parse").End() - } + defer trace.StartRegion(ctx, "parse").End() c := compiler.NewCompiler(sql, combo) if err := c.ParseCatalog(sql.Schema); err != nil { fmt.Fprintf(stderr, "# package %s\n", name) @@ -272,10 +261,7 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C } func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, result *compiler.Result) (string, *plugin.CodeGenResponse, error) { - var region *trace.Region - if debug.Traced { - region = trace.StartRegion(ctx, "codegen") - } + defer trace.StartRegion(ctx, "codegen").End() req := codeGenRequest(result, combo) var handler ext.Handler var out string @@ -319,8 +305,5 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re return "", nil, fmt.Errorf("missing language backend") } resp, err := handler.Generate(ctx, req) - if region != nil { - region.End() - } return out, resp, err } diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 640132ca92..efba759adb 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -71,7 +71,7 @@ func buildStructs(req *plugin.CodeGenRequest) []Struct { }) } s := Struct{ - Table: plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name}, + Table: &plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name}, Name: StructName(structName, req.Settings), Comment: table.Comment, } @@ -214,7 +214,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) c := query.Columns[i] sameName := f.Name == StructName(columnName(c, i), req.Settings) sameType := f.Type == goType(req, c) - sameTable := sdk.SameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema) + sameTable := sdk.SameTableName(c.Table, s.Table, req.Catalog.DefaultSchema) if !sameName || !sameType || !sameTable { same = false } diff --git a/internal/codegen/golang/struct.go b/internal/codegen/golang/struct.go index f72a228ae3..c1dfd5663d 100644 --- a/internal/codegen/golang/struct.go +++ b/internal/codegen/golang/struct.go @@ -9,7 +9,7 @@ import ( ) type Struct struct { - Table plugin.Identifier + Table *plugin.Identifier Name string Fields []Field Comment string diff --git a/internal/codegen/json/gen.go b/internal/codegen/json/gen.go index f481d009c6..75ab3941cf 100644 --- a/internal/codegen/json/gen.go +++ b/internal/codegen/json/gen.go @@ -11,13 +11,13 @@ import ( "github.com/kyleconroy/sqlc/internal/plugin" ) -func parseOptions(req *plugin.CodeGenRequest) (plugin.JSONCode, error) { +func parseOptions(req *plugin.CodeGenRequest) (*plugin.JSONCode, error) { if req.Settings == nil { - return plugin.JSONCode{}, nil + return new(plugin.JSONCode), nil } if req.Settings.Codegen != nil { if len(req.Settings.Codegen.Options) != 0 { - var options plugin.JSONCode + var options *plugin.JSONCode dec := ejson.NewDecoder(bytes.NewReader(req.Settings.Codegen.Options)) dec.DisallowUnknownFields() if err := dec.Decode(&options); err != nil { @@ -27,9 +27,9 @@ func parseOptions(req *plugin.CodeGenRequest) (plugin.JSONCode, error) { } } if req.Settings.Json != nil { - return *req.Settings.Json, nil + return req.Settings.Json, nil } - return plugin.JSONCode{}, nil + return new(plugin.JSONCode), nil } func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) { diff --git a/internal/debug/dump.go b/internal/debug/dump.go index 5f95a7fe43..d2e25bab7c 100644 --- a/internal/debug/dump.go +++ b/internal/debug/dump.go @@ -9,14 +9,12 @@ import ( ) var Active bool -var Traced bool var Debug opts.Debug func init() { Active = os.Getenv("SQLCDEBUG") != "" if Active { Debug = opts.DebugFromEnv() - Traced = Debug.Trace != "" } } diff --git a/internal/tracer/trace.go b/internal/tracer/trace.go index 38bf6ace9e..d0c265f1c7 100644 --- a/internal/tracer/trace.go +++ b/internal/tracer/trace.go @@ -9,18 +9,17 @@ import ( "github.com/kyleconroy/sqlc/internal/debug" ) -func Start(base context.Context) (context.Context, func(), error) { - if !debug.Traced { - return base, func() {}, nil - } - +// Start starts Go's runtime tracing facility. +// Traces will be written to the file named by [debug.Debug.Trace]. +// It also starts a new [*trace.Task] that will be stopped when the cleanup is called. +func Start(base context.Context) (_ context.Context, cleanup func(), _ error) { f, err := os.Create(debug.Debug.Trace) if err != nil { - return base, func() {}, fmt.Errorf("failed to create trace output file: %v", err) + return base, cleanup, fmt.Errorf("failed to create trace output file: %v", err) } if err := trace.Start(f); err != nil { - return base, func() {}, fmt.Errorf("failed to start trace: %v", err) + return base, cleanup, fmt.Errorf("failed to start trace: %v", err) } ctx, task := trace.NewTask(base, "sqlc")