Skip to content

Commit

Permalink
Prevent cobra-init from overwriting files
Browse files Browse the repository at this point in the history
Running `cobra-cli init` or `cobra-cli add` would overwrite existing
files without warning, which could result in unexpected loss of data.
This commit modifies cobra-cli so that by default it will not
overwrite files, and adds the `--force` flag to the `init` and `add`
subcommands to allow it to overwrite files.

Closes spf13#59
  • Loading branch information
larsks committed Feb 3, 2023
1 parent 31479f1 commit a4480c1
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 44 deletions.
6 changes: 5 additions & 1 deletion cmd/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ Example: cobra add server -> resulting in a new cmd/server.go`,
cobra.CheckErr(fmt.Errorf("add needs a name for the command"))
}

force, err := cmd.Flags().GetBool("force")
cobra.CheckErr(err)

wd, err := os.Getwd()
cobra.CheckErr(err)

Expand All @@ -57,7 +60,7 @@ Example: cobra add server -> resulting in a new cmd/server.go`,
},
}

cobra.CheckErr(command.Create())
cobra.CheckErr(command.Create(force))

fmt.Printf("%s created at %s\n", command.CmdName, command.AbsolutePath)
},
Expand All @@ -67,6 +70,7 @@ Example: cobra add server -> resulting in a new cmd/server.go`,
func init() {
addCmd.Flags().StringVarP(&packageName, "package", "t", "", "target package name (e.g. github.com/spf13/hugo)")
addCmd.Flags().StringVarP(&parentName, "parent", "p", "rootCmd", "variable name of parent command for this command")
addCmd.Flags().BoolP("force", "f", false, "overwrite files")
cobra.CheckErr(addCmd.Flags().MarkDeprecated("package", "this operation has been removed."))
}

Expand Down
4 changes: 2 additions & 2 deletions cmd/add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ func TestGoldenAddCmd(t *testing.T) {
}
defer os.RemoveAll(command.AbsolutePath)

assertNoErr(t, command.Project.Create())
assertNoErr(t, command.Create())
assertNoErr(t, command.Project.Create(false))
assertNoErr(t, command.Create(false))

generatedFile := fmt.Sprintf("%s/cmd/%s.go", command.AbsolutePath, command.CmdName)
goldenFile := fmt.Sprintf("testdata/%s.go.golden", command.CmdName)
Expand Down
14 changes: 10 additions & 4 deletions cmd/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ and the appropriate structure for a Cobra-based CLI application.
Cobra init must be run inside of a go module (please run "go mod init <MODNAME>" first)
`,

Run: func(_ *cobra.Command, args []string) {
projectPath, err := initializeProject(args)
Run: func(cmd *cobra.Command, args []string) {
force, err := cmd.Flags().GetBool("force")
cobra.CheckErr(err)
projectPath, err := initializeProject(force, args)
cobra.CheckErr(err)
cobra.CheckErr(goGet("github.com/spf13/cobra"))
if viper.GetBool("useViper") {
Expand All @@ -49,7 +51,11 @@ Cobra init must be run inside of a go module (please run "go mod init <MODNAME>"
}
)

func initializeProject(args []string) (string, error) {
func init() {
initCmd.Flags().BoolP("force", "f", false, "overwrite files")
}

func initializeProject(force bool, args []string) (string, error) {
wd, err := os.Getwd()
if err != nil {
return "", err
Expand All @@ -72,7 +78,7 @@ func initializeProject(args []string) (string, error) {
AppName: path.Base(modName),
}

if err := project.Create(); err != nil {
if err := project.Create(force); err != nil {
return "", err
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestGoldenInitCmd(t *testing.T) {

viper.Set("useViper", true)
viper.Set("license", "apache")
projectPath, err := initializeProject(tt.args)
projectPath, err := initializeProject(false, tt.args)
defer func() {
if projectPath != "" {
os.RemoveAll(projectPath)
Expand Down
101 changes: 65 additions & 36 deletions cmd/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,36 @@ import (
"github.com/spf13/cobra-cli/tpl"
)

// Project contains name, license and paths to projects.
type Project struct {
// v2
PkgName string
Copyright string
AbsolutePath string
Legal License
Viper bool
AppName string
}
type (
// Project contains name, license and paths to projects.
Project struct {
// v2
PkgName string
Copyright string
AbsolutePath string
Legal License
Viper bool
AppName string
}

type Command struct {
CmdName string
CmdParent string
*Project
}
Command struct {
CmdName string
CmdParent string
*Project
}

func (p *Project) Create() error {
// check if AbsolutePath exists
if _, err := os.Stat(p.AbsolutePath); os.IsNotExist(err) {
// create directory
if err := os.Mkdir(p.AbsolutePath, 0754); err != nil {
return err
}
createFileFunc func(p *Project) error
)

var (
projectFiles = map[string]createFileFunc{
"%s/main.go": createMain,
"%s/cmd/root.go": createRootCmd,
"%s/LICENSE": createLicenseFile,
}
)

func createMain(p *Project) error {
// create main.go
mainFile, err := os.Create(fmt.Sprintf("%s/main.go", p.AbsolutePath))
if err != nil {
Expand All @@ -43,13 +47,12 @@ func (p *Project) Create() error {
defer mainFile.Close()

mainTemplate := template.Must(template.New("main").Parse(string(tpl.MainTemplate())))
err = mainTemplate.Execute(mainFile, p)
if err != nil {
return err
}
return mainTemplate.Execute(mainFile, p)
}

func createRootCmd(p *Project) error {
// create cmd/root.go
if _, err = os.Stat(fmt.Sprintf("%s/cmd", p.AbsolutePath)); os.IsNotExist(err) {
if _, err := os.Stat(fmt.Sprintf("%s/cmd", p.AbsolutePath)); os.IsNotExist(err) {
cobra.CheckErr(os.Mkdir(fmt.Sprintf("%s/cmd", p.AbsolutePath), 0751))
}
rootFile, err := os.Create(fmt.Sprintf("%s/cmd/root.go", p.AbsolutePath))
Expand All @@ -59,16 +62,10 @@ func (p *Project) Create() error {
defer rootFile.Close()

rootTemplate := template.Must(template.New("root").Parse(string(tpl.RootTemplate())))
err = rootTemplate.Execute(rootFile, p)
if err != nil {
return err
}

// create license
return p.createLicenseFile()
return rootTemplate.Execute(rootFile, p)
}

func (p *Project) createLicenseFile() error {
func createLicenseFile(p *Project) error {
data := map[string]interface{}{
"copyright": copyrightLine(),
}
Expand All @@ -82,7 +79,39 @@ func (p *Project) createLicenseFile() error {
return licenseTemplate.Execute(licenseFile, data)
}

func (c *Command) Create() error {
func (p *Project) Create(force bool) error {
// check if AbsolutePath exists
if _, err := os.Stat(p.AbsolutePath); os.IsNotExist(err) {
// create directory
if err := os.Mkdir(p.AbsolutePath, 0754); err != nil {
return err
}
}

// Check to make sure we don't overwrite things unless we have --force
if !force {
for path, _ := range projectFiles {
abspath := fmt.Sprintf(path, p.AbsolutePath)
if _, err := os.Stat(abspath); err == nil {
return fmt.Errorf("%s already exists; use --force to overwrite", abspath)
}
}
}

for _, createFunc := range projectFiles {
if err := createFunc(p); err != nil {
return err
}
}

return nil
}

func (c *Command) Create(force bool) error {
abspath := fmt.Sprintf("%s/cmd/%s.go", c.AbsolutePath, c.CmdName)
if _, err := os.Stat(abspath); err == nil && !force {
return fmt.Errorf("%s already exists; use --force to overwrite", abspath)
}
cmdFile, err := os.Create(fmt.Sprintf("%s/cmd/%s.go", c.AbsolutePath, c.CmdName))
if err != nil {
return err
Expand Down

0 comments on commit a4480c1

Please sign in to comment.