Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: clean up some tiny things #1966

Merged
merged 4 commits into from
Nov 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ install:
test:
go test ./...

vet:
go vet ./...

test-examples:
go test --tags=examples ./...

build-endtoend:
cd ./internal/endtoend/testdata && go build ./...

test-ci: test-examples build-endtoend
test-ci: test-examples build-endtoend vet

regen: sqlc-dev sqlc-gen-json
go run ./scripts/regenerate/
Expand Down
62 changes: 29 additions & 33 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,46 +45,49 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int
rootCmd.SetOut(stdout)
rootCmd.SetErr(stderr)

ctx, cleanup, err := tracer.Start(context.Background())
if err != nil {
fmt.Printf("failed to start trace: %v\n", err)
return 1
ctx := context.Background()
if debug.Debug.Trace != "" {
tracectx, cleanup, err := tracer.Start(ctx)
if err != nil {
fmt.Printf("failed to start trace: %v\n", err)
return 1
}
ctx = tracectx
defer cleanup()
}
defer cleanup()

if err := rootCmd.ExecuteContext(ctx); err == nil {
return 0
}
if exitError, ok := err.(*exec.ExitError); ok {
return exitError.ExitCode()
if err := rootCmd.ExecuteContext(ctx); err != nil {
fmt.Fprintf(stderr, "%v\n", err)
if exitError, ok := err.(*exec.ExitError); ok {
return exitError.ExitCode()
} else {
return 1
}
}
return 1
return 0
}

var version string

var versionCmd = &cobra.Command{
Use: "version",
Short: "Print the sqlc version number",
Run: func(cmd *cobra.Command, args []string) {
if debug.Traced {
defer trace.StartRegion(cmd.Context(), "version").End()
}
RunE: func(cmd *cobra.Command, args []string) error {
defer trace.StartRegion(cmd.Context(), "version").End()
if version == "" {
fmt.Printf("%s\n", info.Version)
} else {
fmt.Printf("%s\n", version)
}
return nil
},
}

var initCmd = &cobra.Command{
Use: "init",
Short: "Create an empty sqlc.yaml settings file",
RunE: func(cmd *cobra.Command, args []string) error {
if debug.Traced {
defer trace.StartRegion(cmd.Context(), "init").End()
}
defer trace.StartRegion(cmd.Context(), "init").End()
file := "sqlc.yaml"
if f := cmd.Flag("file"); f != nil && f.Changed {
file = f.Value.String()
Expand Down Expand Up @@ -153,26 +156,23 @@ func getConfigPath(stderr io.Writer, f *pflag.Flag) (string, string) {
var genCmd = &cobra.Command{
Use: "generate",
Short: "Generate Go code from SQL",
Run: func(cmd *cobra.Command, args []string) {
if debug.Traced {
defer trace.StartRegion(cmd.Context(), "generate").End()
}
RunE: func(cmd *cobra.Command, args []string) error {
defer trace.StartRegion(cmd.Context(), "generate").End()
stderr := cmd.ErrOrStderr()
dir, name := getConfigPath(stderr, cmd.Flag("file"))
output, err := Generate(cmd.Context(), ParseEnv(cmd), dir, name, stderr)
if err != nil {
os.Exit(1)
}
if debug.Traced {
defer trace.StartRegion(cmd.Context(), "writefiles").End()
return err
}
defer trace.StartRegion(cmd.Context(), "writefiles").End()
for filename, source := range output {
os.MkdirAll(filepath.Dir(filename), 0755)
if err := os.WriteFile(filename, []byte(source), 0644); err != nil {
fmt.Fprintf(stderr, "%s: %s\n", filename, err)
os.Exit(1)
return err
}
}
return nil
},
}

Expand All @@ -194,9 +194,7 @@ var checkCmd = &cobra.Command{
Use: "compile",
Short: "Statically check SQL for syntax and type errors",
RunE: func(cmd *cobra.Command, args []string) error {
if debug.Traced {
defer trace.StartRegion(cmd.Context(), "compile").End()
}
defer trace.StartRegion(cmd.Context(), "compile").End()
stderr := cmd.ErrOrStderr()
dir, name := getConfigPath(stderr, cmd.Flag("file"))
if _, err := Generate(cmd.Context(), ParseEnv(cmd), dir, name, stderr); err != nil {
Expand Down Expand Up @@ -239,9 +237,7 @@ var diffCmd = &cobra.Command{
Use: "diff",
Short: "Compare the generated files to the existing files",
RunE: func(cmd *cobra.Command, args []string) error {
if debug.Traced {
defer trace.StartRegion(cmd.Context(), "diff").End()
}
defer trace.StartRegion(cmd.Context(), "diff").End()
stderr := cmd.ErrOrStderr()
dir, name := getConfigPath(stderr, cmd.Flag("file"))
if err := Diff(cmd.Context(), ParseEnv(cmd), dir, name, stderr); err != nil {
Expand Down
5 changes: 1 addition & 4 deletions internal/cmd/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,14 @@ import (
"strings"

"github.com/cubicdaiya/gonp"
"github.com/kyleconroy/sqlc/internal/debug"
)

func Diff(ctx context.Context, e Env, dir, name string, stderr io.Writer) error {
output, err := Generate(ctx, e, dir, name, stderr)
if err != nil {
return err
}
if debug.Traced {
defer trace.StartRegion(ctx, "checkfiles").End()
}
defer trace.StartRegion(ctx, "checkfiles").End()
var errored bool

keys := make([]string, 0, len(output))
Expand Down
31 changes: 7 additions & 24 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,12 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
name = sql.Plugin.Plugin
}

var packageRegion *trace.Region
if debug.Traced {
packageRegion = trace.StartRegion(ctx, "package")
trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang)
}
packageRegion := trace.StartRegion(ctx, "package")
trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang)

result, failed := parse(ctx, name, dir, sql.SQL, combo, parseOpts, stderr)
if failed {
if packageRegion != nil {
packageRegion.End()
}
packageRegion.End()
errored = true
break
}
Expand All @@ -213,9 +208,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
fmt.Fprintf(stderr, "# package %s\n", name)
fmt.Fprintf(stderr, "error generating code: %s\n", err)
errored = true
if packageRegion != nil {
packageRegion.End()
}
packageRegion.End()
continue
}

Expand All @@ -227,9 +220,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
filename := filepath.Join(dir, out, n)
output[filename] = source
}
if packageRegion != nil {
packageRegion.End()
}
packageRegion.End()
}

if errored {
Expand All @@ -239,9 +230,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
}

func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) {
if debug.Traced {
defer trace.StartRegion(ctx, "parse").End()
}
defer trace.StartRegion(ctx, "parse").End()
c := compiler.NewCompiler(sql, combo)
if err := c.ParseCatalog(sql.Schema); err != nil {
fmt.Fprintf(stderr, "# package %s\n", name)
Expand Down Expand Up @@ -272,10 +261,7 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
}

func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, result *compiler.Result) (string, *plugin.CodeGenResponse, error) {
var region *trace.Region
if debug.Traced {
region = trace.StartRegion(ctx, "codegen")
}
defer trace.StartRegion(ctx, "codegen").End()
req := codeGenRequest(result, combo)
var handler ext.Handler
var out string
Expand Down Expand Up @@ -319,8 +305,5 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re
return "", nil, fmt.Errorf("missing language backend")
}
resp, err := handler.Generate(ctx, req)
if region != nil {
region.End()
}
return out, resp, err
}
4 changes: 2 additions & 2 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func buildStructs(req *plugin.CodeGenRequest) []Struct {
})
}
s := Struct{
Table: plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name},
Table: &plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name},
Name: StructName(structName, req.Settings),
Comment: table.Comment,
}
Expand Down Expand Up @@ -214,7 +214,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
c := query.Columns[i]
sameName := f.Name == StructName(columnName(c, i), req.Settings)
sameType := f.Type == goType(req, c)
sameTable := sdk.SameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema)
sameTable := sdk.SameTableName(c.Table, s.Table, req.Catalog.DefaultSchema)
if !sameName || !sameType || !sameTable {
same = false
}
Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

type Struct struct {
Table plugin.Identifier
Table *plugin.Identifier
Name string
Fields []Field
Comment string
Expand Down
10 changes: 5 additions & 5 deletions internal/codegen/json/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import (
"github.com/kyleconroy/sqlc/internal/plugin"
)

func parseOptions(req *plugin.CodeGenRequest) (plugin.JSONCode, error) {
func parseOptions(req *plugin.CodeGenRequest) (*plugin.JSONCode, error) {
if req.Settings == nil {
return plugin.JSONCode{}, nil
return new(plugin.JSONCode), nil
}
if req.Settings.Codegen != nil {
if len(req.Settings.Codegen.Options) != 0 {
var options plugin.JSONCode
var options *plugin.JSONCode
dec := ejson.NewDecoder(bytes.NewReader(req.Settings.Codegen.Options))
dec.DisallowUnknownFields()
if err := dec.Decode(&options); err != nil {
Expand All @@ -27,9 +27,9 @@ func parseOptions(req *plugin.CodeGenRequest) (plugin.JSONCode, error) {
}
}
if req.Settings.Json != nil {
return *req.Settings.Json, nil
return req.Settings.Json, nil
}
return plugin.JSONCode{}, nil
return new(plugin.JSONCode), nil
}

func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
Expand Down
2 changes: 0 additions & 2 deletions internal/debug/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ import (
)

var Active bool
var Traced bool
var Debug opts.Debug

func init() {
Active = os.Getenv("SQLCDEBUG") != ""
if Active {
Debug = opts.DebugFromEnv()
Traced = Debug.Trace != ""
}
}

Expand Down
13 changes: 6 additions & 7 deletions internal/tracer/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,17 @@ import (
"github.com/kyleconroy/sqlc/internal/debug"
)

func Start(base context.Context) (context.Context, func(), error) {
if !debug.Traced {
return base, func() {}, nil
}

// Start starts Go's runtime tracing facility.
// Traces will be written to the file named by [debug.Debug.Trace].
// It also starts a new [*trace.Task] that will be stopped when the cleanup is called.
func Start(base context.Context) (_ context.Context, cleanup func(), _ error) {
f, err := os.Create(debug.Debug.Trace)
if err != nil {
return base, func() {}, fmt.Errorf("failed to create trace output file: %v", err)
return base, cleanup, fmt.Errorf("failed to create trace output file: %v", err)
}

if err := trace.Start(f); err != nil {
return base, func() {}, fmt.Errorf("failed to start trace: %v", err)
return base, cleanup, fmt.Errorf("failed to start trace: %v", err)
}

ctx, task := trace.NewTask(base, "sqlc")
Expand Down