Skip to content

Commit

Permalink
cmd: Allow config file location to be specified (#863)
Browse files Browse the repository at this point in the history
* cmd: Allow config file location to be specified
  • Loading branch information
kyleconroy authored Jan 26, 2021
1 parent df30bbf commit 197ed3b
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 42 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
github.com/lib/pq v1.9.0
github.com/pingcap/parser v0.0.0-20201024025010-3b2fb4b41d73
github.com/spf13/cobra v1.1.1
github.com/spf13/pflag v1.0.5
golang.org/x/tools v0.0.0-20191219041853-979b82bfef62 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v3 v3.0.0-20200121175148-a6ecf24a6d71
Expand Down
52 changes: 36 additions & 16 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path/filepath"

"github.com/spf13/cobra"
"github.com/spf13/pflag"
yaml "gopkg.in/yaml.v3"

"github.com/kyleconroy/sqlc/internal/config"
Expand All @@ -17,6 +18,8 @@ import (
// Do runs the command logic.
func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int {
rootCmd := &cobra.Command{Use: "sqlc", SilenceUsage: true}
rootCmd.PersistentFlags().StringP("file", "f", "", "specify an alternate config file (default: sqlc.yaml)")

rootCmd.AddCommand(checkCmd)
rootCmd.AddCommand(genCmd)
rootCmd.AddCommand(initCmd)
Expand Down Expand Up @@ -57,14 +60,18 @@ var initCmd = &cobra.Command{
Use: "init",
Short: "Create an empty sqlc.yaml settings file",
RunE: func(cmd *cobra.Command, args []string) error {
if _, err := os.Stat("sqlc.yaml"); !os.IsNotExist(err) {
file := "sqlc.yaml"
if f := cmd.Flag("file"); f != nil {
file = f.Value.String()
}
if _, err := os.Stat(file); !os.IsNotExist(err) {
return nil
}
blob, err := yaml.Marshal(config.V1GenerateSettings{Version: "1"})
if err != nil {
return err
}
return ioutil.WriteFile("sqlc.yaml", blob, 0644)
return ioutil.WriteFile(file, blob, 0644)
},
}

Expand All @@ -75,22 +82,39 @@ func ParseEnv() Env {
return Env{}
}

var genCmd = &cobra.Command{
Use: "generate",
Short: "Generate Go code from SQL",
Run: func(cmd *cobra.Command, args []string) {
stderr := cmd.ErrOrStderr()
dir, err := os.Getwd()
func getConfigPath(stderr io.Writer, f *pflag.Flag) (string, string) {
if f != nil {
file := f.Value.String()
if file == "" {
fmt.Fprintln(stderr, "error parsing config: file argument is empty")
os.Exit(1)
}
abspath, err := filepath.Abs(file)
if err != nil {
fmt.Fprintf(stderr, "error parsing config: absolute file path lookup failed: %s\n", err)
os.Exit(1)
}
return filepath.Dir(abspath), filepath.Base(abspath)
} else {
wd, err := os.Getwd()
if err != nil {
fmt.Fprintln(stderr, "error parsing sqlc.json: file does not exist")
os.Exit(1)
}
return wd, ""
}
}

output, err := Generate(ParseEnv(), dir, stderr)
var genCmd = &cobra.Command{
Use: "generate",
Short: "Generate Go code from SQL",
Run: func(cmd *cobra.Command, args []string) {
stderr := cmd.ErrOrStderr()
dir, name := getConfigPath(stderr, cmd.Flag("file"))
output, err := Generate(ParseEnv(), dir, name, stderr)
if err != nil {
os.Exit(1)
}

for filename, source := range output {
os.MkdirAll(filepath.Dir(filename), 0755)
if err := ioutil.WriteFile(filename, []byte(source), 0644); err != nil {
Expand All @@ -106,12 +130,8 @@ var checkCmd = &cobra.Command{
Short: "Statically check SQL for syntax and type errors",
RunE: func(cmd *cobra.Command, args []string) error {
stderr := cmd.ErrOrStderr()
dir, err := os.Getwd()
if err != nil {
fmt.Fprintln(stderr, "error parsing sqlc.json: file does not exist")
os.Exit(1)
}
if _, err := Generate(Env{}, dir, stderr); err != nil {
dir, name := getConfigPath(stderr, cmd.Flag("file"))
if _, err := Generate(Env{}, dir, name, stderr); err != nil {
os.Exit(1)
}
return nil
Expand Down
49 changes: 27 additions & 22 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,34 +44,39 @@ type outPair struct {
config.SQL
}

func Generate(e Env, dir string, stderr io.Writer) (map[string]string, error) {
var yamlMissing, jsonMissing bool
yamlPath := filepath.Join(dir, "sqlc.yaml")
jsonPath := filepath.Join(dir, "sqlc.json")
func Generate(e Env, dir, filename string, stderr io.Writer) (map[string]string, error) {
configPath := ""
if filename != "" {
configPath = filepath.Join(dir, filename)
} else {
var yamlMissing, jsonMissing bool
yamlPath := filepath.Join(dir, "sqlc.yaml")
jsonPath := filepath.Join(dir, "sqlc.json")

if _, err := os.Stat(yamlPath); os.IsNotExist(err) {
yamlMissing = true
}
if _, err := os.Stat(jsonPath); os.IsNotExist(err) {
jsonMissing = true
}

if _, err := os.Stat(yamlPath); os.IsNotExist(err) {
yamlMissing = true
}
if _, err := os.Stat(jsonPath); os.IsNotExist(err) {
jsonMissing = true
}
if yamlMissing && jsonMissing {
fmt.Fprintln(stderr, "error parsing sqlc.json: file does not exist")
return nil, errors.New("config file missing")
}

if yamlMissing && jsonMissing {
fmt.Fprintln(stderr, "error parsing sqlc.json: file does not exist")
return nil, errors.New("config file missing")
}
if !yamlMissing && !jsonMissing {
fmt.Fprintln(stderr, "error: both sqlc.json and sqlc.yaml files present")
return nil, errors.New("sqlc.json and sqlc.yaml present")
}

if !yamlMissing && !jsonMissing {
fmt.Fprintln(stderr, "error: both sqlc.json and sqlc.yaml files present")
return nil, errors.New("sqlc.json and sqlc.yaml present")
configPath = yamlPath
if yamlMissing {
configPath = jsonPath
}
}

configPath := yamlPath
if yamlMissing {
configPath = jsonPath
}
base := filepath.Base(configPath)

blob, err := ioutil.ReadFile(configPath)
if err != nil {
fmt.Fprintf(stderr, "error parsing %s: file does not exist\n", base)
Expand Down
8 changes: 4 additions & 4 deletions internal/endtoend/endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestExamples(t *testing.T) {
t.Parallel()
path := filepath.Join(examples, tc)
var stderr bytes.Buffer
output, err := cmd.Generate(cmd.Env{}, path, &stderr)
output, err := cmd.Generate(cmd.Env{}, path, "", &stderr)
if err != nil {
t.Fatalf("sqlc generate failed: %s", stderr.String())
}
Expand All @@ -62,7 +62,7 @@ func BenchmarkExamples(b *testing.B) {
path := filepath.Join(examples, tc)
for i := 0; i < b.N; i++ {
var stderr bytes.Buffer
cmd.Generate(cmd.Env{}, path, &stderr)
cmd.Generate(cmd.Env{}, path, "", &stderr)
}
})
}
Expand Down Expand Up @@ -91,7 +91,7 @@ func TestReplay(t *testing.T) {
path, _ := filepath.Abs(tc)
var stderr bytes.Buffer
expected := expectedStderr(t, path)
output, err := cmd.Generate(cmd.Env{}, path, &stderr)
output, err := cmd.Generate(cmd.Env{}, path, "", &stderr)
if len(expected) == 0 && err != nil {
t.Fatalf("sqlc generate failed: %s", stderr.String())
}
Expand Down Expand Up @@ -184,7 +184,7 @@ func BenchmarkReplay(b *testing.B) {
path, _ := filepath.Abs(tc)
for i := 0; i < b.N; i++ {
var stderr bytes.Buffer
cmd.Generate(cmd.Env{}, path, &stderr)
cmd.Generate(cmd.Env{}, path, "", &stderr)
}
})
}
Expand Down

0 comments on commit 197ed3b

Please sign in to comment.