Skip to content

Commit

Permalink
feat: use sub-command for plugins (#61)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: the CLI command and arguments was changed

See the discussion:
#20 (comment)

I think we don't need to support running *2ms* for multiple plugins at
once. It is a rare case and it is confusing the command line arguments.

Instead, I'm proposing using *SubCommand* for each plugin.

---------

Co-authored-by: Jossef Harush Kadouri <[email protected]>
  • Loading branch information
Baruch Odem (Rothkoff) and jossef authored May 14, 2023
1 parent 842cec6 commit 59a5c48
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 125 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-validation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:

- run: make build
- name: Run docker and check its output
run: if docker run -t checkmarx/2ms:latest | grep "no scan plugin initialized"; then
run: if docker run -t checkmarx/2ms:latest --version | grep "2ms version"; then
echo "Docker ran as expected";
else
echo "Docker did not run as expected";
Expand Down
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,10 @@ save: build
docker save $(image_name) > $(image_file_name)

run:
docker run -it $(image_name) $(ARGS)
docker run -it $(image_name) $(ARGS)

# To run golangci-lint, you need to install it first: https://golangci-lint.run/usage/install/#local-installation
lint:
golangci-lint run -v -E gofmt --timeout=5m
lint-fix:
golangci-lint run -v -E gofmt --fix --timeout=5m
81 changes: 31 additions & 50 deletions cmd/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"fmt"
"os"
"strings"

Expand All @@ -18,21 +19,29 @@ import (

const timeSleepInterval = 50

var Version = "0.0.0"

var rootCmd = &cobra.Command{
Use: "2ms",
Short: "2ms Secrets Detection",
Run: execute,
Version: Version,
}

var Version = ""

var allPlugins = []plugins.IPlugin{
&plugins.ConfluencePlugin{},
&plugins.DiscordPlugin{},
&plugins.RepositoryPlugin{},
}

var channels = plugins.Channels{
Items: make(chan plugins.Item),
Errors: make(chan error),
WaitGroup: &sync.WaitGroup{},
}

var report = reporting.Init()
var secretsChan = make(chan reporting.Secret)

func initLog() {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
ll, err := rootCmd.Flags().GetString("log-level")
Expand All @@ -59,18 +68,21 @@ func initLog() {

func Execute() {
cobra.OnInitialize(initLog)
rootCmd.Flags().BoolP("all", "", true, "scan all plugins")
rootCmd.Flags().StringSlice("tags", []string{"all"}, "select rules to be applied")
rootCmd.PersistentFlags().BoolP("all", "", true, "scan all plugins")
rootCmd.PersistentFlags().StringSlice("tags", []string{"all"}, "select rules to be applied")
rootCmd.PersistentFlags().StringP("log-level", "", "info", "log level (trace, debug, info, warn, error, fatal)")

rootCmd.PersistentPreRun = preRun
rootCmd.PersistentPostRun = postRun

for _, plugin := range allPlugins {
err := plugin.DefineCommandLineArgs(rootCmd)
subCommand, err := plugin.DefineCommand(channels)
if err != nil {
log.Fatal().Msg(err.Error())
log.Fatal().Msg(fmt.Sprintf("error while defining command for plugin %s: %s", plugin.GetName(), err.Error()))
}
rootCmd.AddCommand(subCommand)
}

rootCmd.PersistentFlags().StringP("log-level", "", "info", "log level (trace, debug, info, warn, error, fatal)")

if err := rootCmd.Execute(); err != nil {
log.Fatal().Msg(err.Error())
}
Expand All @@ -90,7 +102,7 @@ func validateTags(tags []string) {
}
}

func execute(cmd *cobra.Command, args []string) {
func preRun(cmd *cobra.Command, args []string) {
tags, err := cmd.Flags().GetStringSlice("tags")
if err != nil {
log.Fatal().Msg(err.Error())
Expand All @@ -99,59 +111,29 @@ func execute(cmd *cobra.Command, args []string) {
validateTags(tags)

secrets := secrets.Init(tags)
report := reporting.Init()

var itemsChannel = make(chan plugins.Item)
var secretsChannel = make(chan reporting.Secret)
var errorsChannel = make(chan error)

var wg sync.WaitGroup

// -------------------------------------
// Get content from plugins
pluginsInitialized := 0
for _, plugin := range allPlugins {
err := plugin.Initialize(cmd)
if err != nil {
log.Error().Msg(err.Error())
continue
}
pluginsInitialized += 1
}

if pluginsInitialized == 0 {
log.Fatal().Msg("no scan plugin initialized. At least one plugin must be initialized to proceed. Stopping")
os.Exit(1)
}

for _, plugin := range allPlugins {
if !plugin.IsEnabled() {
continue
}

wg.Add(1)
go plugin.GetItems(itemsChannel, errorsChannel, &wg)
}

go func() {
for {
select {
case item := <-itemsChannel:
case item := <-channels.Items:
report.TotalItemsScanned++
wg.Add(1)
go secrets.Detect(secretsChannel, item, &wg)
case secret := <-secretsChannel:
channels.WaitGroup.Add(1)
go secrets.Detect(secretsChan, item, channels.WaitGroup)
case secret := <-secretsChan:
report.TotalSecretsFound++
report.Results[secret.ID] = append(report.Results[secret.ID], secret)
case err, ok := <-errorsChannel:
case err, ok := <-channels.Errors:
if !ok {
return
}
log.Fatal().Msg(err.Error())
}
}
}()
wg.Wait()
}

func postRun(cmd *cobra.Command, args []string) {
channels.WaitGroup.Wait()

// Wait for last secret to be added to report
time.Sleep(time.Millisecond * timeSleepInterval)
Expand All @@ -170,5 +152,4 @@ func execute(cmd *cobra.Command, args []string) {
} else {
os.Exit(0)
}

}
85 changes: 49 additions & 36 deletions plugins/confluence.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package plugins

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
Expand All @@ -14,11 +13,11 @@ import (
)

const (
argConfluence = "confluence"
argConfluenceSpaces = "confluence-spaces"
argConfluenceUsername = "confluence-username"
argConfluenceToken = "confluence-token"
argConfluenceHistory = "history"
argUrl = "url"
argSpaces = "spaces"
argUsername = "username"
argToken = "token"
argHistory = "history"
confluenceDefaultWindow = 25
confluenceMaxRequests = 500
)
Expand All @@ -32,62 +31,76 @@ type ConfluencePlugin struct {
History bool
}

func (p *ConfluencePlugin) IsEnabled() bool {
return p.Enabled
func (p *ConfluencePlugin) GetName() string {
return "confluence"
}

func (p *ConfluencePlugin) GetCredentials() (string, string) {
return p.Username, p.Token
}

func (p *ConfluencePlugin) DefineCommandLineArgs(cmd *cobra.Command) error {
flags := cmd.Flags()
flags.StringP(argConfluence, "", "", "scan confluence url")
flags.StringArray(argConfluenceSpaces, []string{}, "confluence spaces")
flags.StringP(argConfluenceUsername, "", "", "confluence username or email")
flags.StringP(argConfluenceToken, "", "", "confluence token")
flags.BoolP(argConfluenceHistory, "", false, "scan pages history")
return nil
func (p *ConfluencePlugin) DefineCommand(channels Channels) (*cobra.Command, error) {
var confluenceCmd = &cobra.Command{
Use: p.GetName(),
Short: "Scan confluence",
}

flags := confluenceCmd.Flags()
flags.StringP(argUrl, "", "", "confluence url")
flags.StringArray(argSpaces, []string{}, "confluence spaces")
flags.StringP(argUsername, "", "", "confluence username or email")
flags.StringP(argToken, "", "", "confluence token")
flags.BoolP(argHistory, "", false, "scan pages history")
err := confluenceCmd.MarkFlagRequired(argUrl)
if err != nil {
return nil, fmt.Errorf("error while marking '%s' flag as required: %w", argUrl, err)
}

confluenceCmd.Run = func(cmd *cobra.Command, args []string) {
err := p.Initialize(cmd)
if err != nil {
channels.Errors <- fmt.Errorf("error while initializing confluence plugin: %w", err)
return
}

p.GetItems(channels.Items, channels.Errors, channels.WaitGroup)
}

return confluenceCmd, nil
}

func (p *ConfluencePlugin) Initialize(cmd *cobra.Command) error {
flags := cmd.Flags()
confluenceUrl, _ := flags.GetString(argConfluence)
if confluenceUrl == "" {
return errors.New("confluence URL arg is missing. Plugin initialization failed")
url, err := flags.GetString(argUrl)
if err != nil {
return fmt.Errorf("error while getting '%s' flag value: %w", argUrl, err)
}

confluenceUrl = strings.TrimRight(confluenceUrl, "/")
url = strings.TrimRight(url, "/")

confluenceSpaces, _ := flags.GetStringArray(argConfluenceSpaces)
confluenceUsername, _ := flags.GetString(argConfluenceUsername)
confluenceToken, _ := flags.GetString(argConfluenceToken)
runHistory, _ := flags.GetBool(argConfluenceHistory)
spaces, _ := flags.GetStringArray(argSpaces)
username, _ := flags.GetString(argUsername)
token, _ := flags.GetString(argToken)
runHistory, _ := flags.GetBool(argHistory)

if confluenceUsername == "" || confluenceToken == "" {
if username == "" || token == "" {
log.Warn().Msg("confluence credentials were not provided. The scan will be made anonymously only for the public pages")
}

p.Token = confluenceToken
p.Username = confluenceUsername
p.URL = confluenceUrl
p.Spaces = confluenceSpaces
p.Enabled = true
p.Token = token
p.Username = username
p.URL = url
p.Spaces = spaces
p.History = runHistory
p.Limit = make(chan struct{}, confluenceMaxRequests)
return nil
}

func (p *ConfluencePlugin) GetItems(items chan Item, errs chan error, wg *sync.WaitGroup) {
defer wg.Done()

go p.getSpacesItems(items, errs, wg)
wg.Add(1)
p.getSpacesItems(items, errs, wg)
}

func (p *ConfluencePlugin) getSpacesItems(items chan Item, errs chan error, wg *sync.WaitGroup) {
defer wg.Done()

spaces, err := p.getSpaces()
if err != nil {
errs <- err
Expand Down
Loading

0 comments on commit 59a5c48

Please sign in to comment.