diff --git a/e2e/cli_test.go b/e2e/cli_test.go index 24bce8b55e1..1772421494f 100644 --- a/e2e/cli_test.go +++ b/e2e/cli_test.go @@ -389,7 +389,7 @@ var tests = []struct { }, }, validation: func(outputText string) bool { - match, _ := regexp.MatchString(`Total CPU usage for inspect: \d+`, outputText) + match, _ := regexp.MatchString(`Total CPU usage for start_scan: \d+`, outputText) return match }, wantStatus: []int{50}, @@ -405,7 +405,7 @@ var tests = []struct { }, }, validation: func(outputText string) bool { - match, _ := regexp.MatchString(`Total MEM usage for inspect: \d+`, outputText) + match, _ := regexp.MatchString(`Total MEM usage for start_scan: \d+`, outputText) return match }, wantStatus: []int{50}, diff --git a/internal/console/helpers/helpers.go b/internal/console/helpers/helpers.go index 899d150dc47..40950c8a3cd 100644 --- a/internal/console/helpers/helpers.go +++ b/internal/console/helpers/helpers.go @@ -31,11 +31,12 @@ var reportGenerators = map[string]func(path, filename string, body interface{}) // ProgressBar represents a Progress // Writer is the writer output for progress bar type ProgressBar struct { - Writer io.Writer - label string - space int - total float64 - progress chan float64 + Writer io.Writer + label string + space int + total float64 + currentProgress float64 + progress chan float64 } // Printer wil print console output with colors @@ -80,13 +81,14 @@ func (p *ProgressBar) Start(wg *sync.WaitGroup) { const hundredPercent = 100 formmatingString := "\r" + p.label + "[%s %4.1f%% %s]" for { - currentProgress, ok := <-p.progress - if !ok || currentProgress >= p.total { + newProgress, ok := <-p.progress + p.currentProgress += newProgress + if !ok || p.currentProgress >= p.total { fmt.Fprintf(p.Writer, formmatingString, strings.Repeat("=", p.space), 100.0, strings.Repeat("=", p.space)) break } - percentage := currentProgress / p.total * hundredPercent + percentage := p.currentProgress / p.total * hundredPercent convertedPercentage := int(math.Round(float64(p.space+p.space) / hundredPercent * math.Round(percentage))) if percentage >= hundredPercent/2 { firstHalfPercentage = strings.Repeat("=", p.space) diff --git a/internal/console/helpers/helpers_test.go b/internal/console/helpers/helpers_test.go index 38dc65c0f60..5055080b455 100644 --- a/internal/console/helpers/helpers_test.go +++ b/internal/console/helpers/helpers_test.go @@ -127,10 +127,10 @@ func TestProgressBar(t *testing.T) { } go progressBar.Start(&wg) if tt.shouldCheckOutput { - for i := 0; i < 101; i++ { - progress <- float64(i) + for i := 0; i < 100; i++ { + progress <- float64(1) } - progress <- float64(100) + progress <- float64(1) } wg.Wait() splittedOut := strings.Split(out.String(), "\r") diff --git a/internal/console/scan.go b/internal/console/scan.go index 7cd52917cea..eaf7a8f0f58 100644 --- a/internal/console/scan.go +++ b/internal/console/scan.go @@ -29,6 +29,7 @@ import ( yamlParser "github.com/Checkmarx/kics/pkg/parser/yaml" "github.com/Checkmarx/kics/pkg/resolver" "github.com/Checkmarx/kics/pkg/resolver/helm" + "github.com/Checkmarx/kics/pkg/scanner" "github.com/getsentry/sentry-go" "github.com/pkg/errors" "github.com/rs/zerolog/log" @@ -397,7 +398,7 @@ func analyzePaths(paths, types, exclude []string) (typesRes, excludeRes []string func createService(inspector *engine.Inspector, t kics.Tracker, store kics.Storage, - querySource source.FilesystemSource) (*kics.Service, error) { + querySource source.FilesystemSource) ([]*kics.Service, error) { filesSource, err := getFileSystemSourceProvider() if err != nil { return nil, err @@ -421,14 +422,19 @@ func createService(inspector *engine.Inspector, return nil, err } - return &kics.Service{ - SourceProvider: filesSource, - Storage: store, - Parser: combinedParser, - Inspector: inspector, - Tracker: t, - Resolver: combinedResolver, - }, nil + services := make([]*kics.Service, 0, len(combinedParser)) + + for _, parser := range combinedParser { + services = append(services, &kics.Service{ + SourceProvider: filesSource, + Storage: store, + Parser: parser, + Inspector: inspector, + Tracker: t, + Resolver: combinedResolver, + }) + } + return services, nil } func scan(changedDefaultQueryPath bool) error { @@ -475,15 +481,15 @@ func scan(changedDefaultQueryPath bool) error { return err } - service, err := createService(inspector, t, store, *querySource) + services, err := createService(inspector, t, store, *querySource) if err != nil { log.Err(err) return err } - if scanErr := service.StartScan(ctx, scanID, noProgress); scanErr != nil { - log.Err(scanErr) - return scanErr + if err = scanner.StartScan(ctx, scanID, noProgress, services); err != nil { + log.Err(err) + return err } results, err := store.GetVulnerabilities(ctx, scanID) @@ -523,7 +529,7 @@ func getSummary(t *tracker.CITracker, results []model.Vulnerability) model.Summa ScannedFiles: t.FoundFiles, ParsedFiles: t.ParsedFiles, TotalQueries: t.LoadedQueries, - FailedToExecuteQueries: t.LoadedQueries - t.ExecutedQueries, + FailedToExecuteQueries: t.ExecutingQueries - t.ExecutedQueries, FailedSimilarityID: t.FailedSimilarityID, } @@ -581,10 +587,12 @@ func printOutput(outputPath, filename string, body interface{}, formats []string func gracefulShutdown() { c := make(chan os.Signal) signal.Notify(c, os.Interrupt, syscall.SIGTERM) - go func() { + showErrors := consoleHelpers.ShowError("errors") + interruptCode := constants.SignalInterruptCode + go func(showErrors bool, interruptCode int) { <-c - if consoleHelpers.ShowError("errors") { - os.Exit(constants.SignalInterruptCode) + if showErrors { + os.Exit(interruptCode) } - }() + }(showErrors, interruptCode) } diff --git a/internal/tracker/ci.go b/internal/tracker/ci.go index c499bfe38e9..7c1a1fc7d44 100644 --- a/internal/tracker/ci.go +++ b/internal/tracker/ci.go @@ -10,6 +10,7 @@ import ( // and how many files were found and executed type CITracker struct { LoadedQueries int + ExecutingQueries int ExecutedQueries int FoundFiles int ParsedFiles int @@ -39,6 +40,11 @@ func (c *CITracker) TrackQueryLoad(queryAggregation int) { c.LoadedQueries += queryAggregation } +// TrackQueryExecuting adds a executing queries +func (c *CITracker) TrackQueryExecuting(queryAggregation int) { + c.ExecutingQueries += queryAggregation +} + // TrackQueryExecution adds a query executed func (c *CITracker) TrackQueryExecution(queryAggregation int) { c.ExecutedQueries += queryAggregation diff --git a/pkg/engine/inspector.go b/pkg/engine/inspector.go index 42d6eae9b18..2530a872375 100644 --- a/pkg/engine/inspector.go +++ b/pkg/engine/inspector.go @@ -3,12 +3,9 @@ package engine import ( "context" "encoding/json" - "fmt" - "io" - "sync" + "strings" "time" - consoleHelpers "github.com/Checkmarx/kics/internal/console/helpers" "github.com/Checkmarx/kics/internal/metrics" "github.com/Checkmarx/kics/pkg/detector" "github.com/Checkmarx/kics/pkg/detector/docker" @@ -53,6 +50,7 @@ type VulnerabilityBuilder func(ctx *QueryContext, tracker Tracker, v interface{} // GetOutputLines returns the number of lines to be displayed in results outputs type Tracker interface { TrackQueryLoad(queryAggregation int) + TrackQueryExecuting(queryAggregation int) TrackQueryExecution(queryAggregation int) FailedDetectLine() FailedComputeSimilarityID() @@ -158,6 +156,7 @@ func NewInspector( }) } } + failedQueries := make(map[string]error) queriesNumber := sumAllAggregatedQueries(opaQueries) @@ -193,22 +192,15 @@ func sumAllAggregatedQueries(opaQueries []*preparedQuery) int { return sum } -func startProgressBar(hideProgress bool, total int, wg *sync.WaitGroup, progressChannel chan float64) { - wg.Add(1) - progressBar := consoleHelpers.NewProgressBar("Executing queries: ", 10, float64(total), progressChannel) - if hideProgress { - progressBar.Writer = io.Discard - } - go progressBar.Start(wg) -} - // Inspect scan files and return the a list of vulnerabilities found on the process func (c *Inspector) Inspect( ctx context.Context, scanID string, files model.FileMetadatas, hideProgress bool, - baseScanPaths []string) ([]model.Vulnerability, error) { + baseScanPaths []string, + platforms []string, + currentQuery chan<- float64) ([]model.Vulnerability, error) { log.Debug().Msg("engine.Inspect()") combinedFiles := files.Combine() @@ -219,12 +211,9 @@ func (c *Inspector) Inspect( var vulnerabilities []model.Vulnerability vulnerabilities = make([]model.Vulnerability, 0) - currentQuery := make(chan float64, 1) - var wg sync.WaitGroup - startProgressBar(hideProgress, len(c.queries), &wg, currentQuery) - for idx, query := range c.queries { + for _, query := range c.getQueriesByPlat(platforms) { if !hideProgress { - currentQuery <- float64(idx) + currentQuery <- float64(1) } vuls, err := c.doRun(&QueryContext{ @@ -250,12 +239,32 @@ func (c *Inspector) Inspect( c.tracker.TrackQueryExecution(query.metadata.Aggregation) } - close(currentQuery) - wg.Wait() - fmt.Println("\r") + return vulnerabilities, nil } +// LenQueriesByPlat returns the number of queries by platforms +func (c *Inspector) LenQueriesByPlat(platforms []string) int { + count := 0 + for _, query := range c.queries { + if contains(platforms, query.metadata.Platform) { + c.tracker.TrackQueryExecuting(query.metadata.Aggregation) + count++ + } + } + return count +} + +func (c *Inspector) getQueriesByPlat(platforms []string) []*preparedQuery { + queries := make([]*preparedQuery, 0) + for _, query := range c.queries { + if contains(platforms, query.metadata.Platform) { + queries = append(queries, query) + } + } + return queries +} + // EnableCoverageReport enables the flag to create a coverage report func (c *Inspector) EnableCoverageReport() { c.enableCoverageReport = true @@ -360,3 +369,20 @@ func (c *Inspector) decodeQueryResults(ctx *QueryContext, results rego.ResultSet return vulnerabilities, nil } + +// contains is a simple method to check if a slice +// contains an entry +func contains(s []string, e string) bool { + if e == "common" { + return true + } + if e == "k8s" { + e = "kubernetes" + } + for _, a := range s { + if strings.EqualFold(a, e) { + return true + } + } + return false +} diff --git a/pkg/engine/inspector_test.go b/pkg/engine/inspector_test.go index a13483cff98..2629bfdd711 100644 --- a/pkg/engine/inspector_test.go +++ b/pkg/engine/inspector_test.go @@ -289,6 +289,7 @@ func TestInspect(t *testing.T) { //nolint for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + currentQuery := make(chan float64) c := &Inspector{ queries: tt.fields.queries, vb: tt.fields.vb, @@ -299,7 +300,8 @@ func TestInspect(t *testing.T) { //nolint detector: inspDetector, queryExecTimeout: time.Duration(60) * time.Second, } - got, err := c.Inspect(tt.args.ctx, tt.args.scanID, tt.args.files, true, []string{filepath.FromSlash("assets/queries/")}) + got, err := c.Inspect(tt.args.ctx, tt.args.scanID, tt.args.files, + true, []string{filepath.FromSlash("assets/queries/")}, []string{""}, currentQuery) if tt.wantErr { if err == nil { t.Errorf("Inspector.Inspect() = %v,\nwant %v", err, tt.want) diff --git a/pkg/engine/provider/filesystem.go b/pkg/engine/provider/filesystem.go index 823131392f8..8aa653d6099 100644 --- a/pkg/engine/provider/filesystem.go +++ b/pkg/engine/provider/filesystem.go @@ -94,6 +94,9 @@ func (s *FileSystemSourceProvider) GetSources(ctx context.Context, if !fileInfo.IsDir() { c, openFileErr := openScanFile(scanPath, extensions) if openFileErr != nil { + if openFileErr == ErrNotSupportedFile { + continue + } return openFileErr } if sinkErr := sink(ctx, scanPath, c); sinkErr != nil { @@ -185,7 +188,7 @@ func (s *FileSystemSourceProvider) checkConditions(info os.FileInfo, extensions } if f, ok := s.excludes[info.Name()]; ok && containsFile(f, info) { - log.Info().Msgf("File ignored: %s", path) + log.Trace().Msgf("File ignored: %s", path) return true, nil } if !extensions.Include(filepath.Ext(path)) && !extensions.Include(filepath.Base(path)) { diff --git a/pkg/kics/service.go b/pkg/kics/service.go index ca3ddd1c861..f0e90ef9e78 100644 --- a/pkg/kics/service.go +++ b/pkg/kics/service.go @@ -3,8 +3,8 @@ package kics import ( "context" "io" + "sync" - "github.com/Checkmarx/kics/internal/metrics" "github.com/Checkmarx/kics/pkg/engine" "github.com/Checkmarx/kics/pkg/engine/provider" "github.com/Checkmarx/kics/pkg/model" @@ -48,9 +48,15 @@ type Service struct { } // StartScan executes scan over the context, using the scanID as reference -func (s *Service) StartScan(ctx context.Context, scanID string, hideProgress bool) error { +func (s *Service) StartScan( + ctx context.Context, + scanID string, + hideProgress bool, + errCh chan<- error, + wg *sync.WaitGroup, + currentQuery chan<- float64) { log.Debug().Msg("service.StartScan()") - metrics.Metric.Start("get_sources") + defer wg.Done() if err := s.SourceProvider.GetSources( ctx, s.Parser.SupportedExtensions(), @@ -61,18 +67,24 @@ func (s *Service) StartScan(ctx context.Context, scanID string, hideProgress boo return s.resolverSink(ctx, filename, scanID) }, ); err != nil { - return errors.Wrap(err, "failed to read sources") + errCh <- errors.Wrap(err, "failed to read sources") } - metrics.Metric.Stop() - metrics.Metric.Start("inspect") - vulnerabilities, err := s.Inspector.Inspect(ctx, scanID, s.files, hideProgress, s.SourceProvider.GetBasePaths()) + vulnerabilities, err := s.Inspector.Inspect( + ctx, + scanID, + s.files, + hideProgress, + s.SourceProvider.GetBasePaths(), + s.Parser.Platform, + currentQuery, + ) if err != nil { - return errors.Wrap(err, "failed to inspect files") + errCh <- errors.Wrap(err, "failed to inspect files") } - err = s.Storage.SaveVulnerabilities(ctx, vulnerabilities) - metrics.Metric.Stop() - return errors.Wrap(err, "failed to save vulnerabilities") + if err != nil { + errCh <- errors.Wrap(err, "failed to save vulnerabilities") + } } /* diff --git a/pkg/kics/service_test.go b/pkg/kics/service_test.go index b877969452d..3f756ec9f88 100644 --- a/pkg/kics/service_test.go +++ b/pkg/kics/service_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "sync" "testing" "github.com/Checkmarx/kics/internal/storage" @@ -21,13 +22,12 @@ import ( ) // TestService tests the functions [GetVulnerabilities(), GetScanSummary(),StartScan()] and all the methods called by them -func TestService(t *testing.T) { +func TestService(t *testing.T) { //nolint mockParser, mockFilesSource, mockResolver := createParserSourceProvider("../../test/fixtures/test_helm") - type fields struct { SourceProvider provider.SourceProvider Storage Storage - Parser *parser.Parser + Parser []*parser.Parser Inspector *engine.Inspector Tracker Tracker Resolver *resolver.Resolver @@ -71,43 +71,71 @@ func TestService(t *testing.T) { }, } for _, tt := range tests { - s := &Service{ - SourceProvider: tt.fields.SourceProvider, - Storage: tt.fields.Storage, - Parser: tt.fields.Parser, - Inspector: tt.fields.Inspector, - Tracker: tt.fields.Tracker, - Resolver: tt.fields.Resolver, + s := make([]*Service, 0, len(tt.fields.Parser)) + for _, parser := range tt.fields.Parser { + s = append(s, &Service{ + SourceProvider: tt.fields.SourceProvider, + Storage: tt.fields.Storage, + Parser: parser, + Inspector: tt.fields.Inspector, + Tracker: tt.fields.Tracker, + Resolver: tt.fields.Resolver, + }) } t.Run(fmt.Sprintf(tt.name+"_get_vulnerabilities"), func(t *testing.T) { - got, err := s.GetVulnerabilities(tt.args.ctx, tt.args.scanID) - if (err != nil) != tt.wantErr { - t.Errorf("Service.GetVulnerabilities() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want.vulnerabilities) { - t.Errorf("Service.GetVulnerabilities() = %v, want %v", got, tt.want) + for _, serv := range s { + got, err := serv.GetVulnerabilities(tt.args.ctx, tt.args.scanID) + if (err != nil) != tt.wantErr { + t.Errorf("Service.GetVulnerabilities() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want.vulnerabilities) { + t.Errorf("Service.GetVulnerabilities() = %v, want %v", got, tt.want) + } } }) t.Run(fmt.Sprintf(tt.name+"_get_scan_summary"), func(t *testing.T) { - got, err := s.GetScanSummary(tt.args.ctx, tt.args.scanIDs) - if (err != nil) != tt.wantErr { - t.Errorf("Service.GetScanSummary() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want.severitySummary) { - t.Errorf("Service.GetScanSummary() = %v, want %v", got, tt.want) + for _, serv := range s { + got, err := serv.GetScanSummary(tt.args.ctx, tt.args.scanIDs) + if (err != nil) != tt.wantErr { + t.Errorf("Service.GetScanSummary() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want.severitySummary) { + t.Errorf("Service.GetScanSummary() = %v, want %v", got, tt.want) + } } }) t.Run(fmt.Sprintf(tt.name+"_start_scan"), func(t *testing.T) { - if err := s.StartScan(tt.args.ctx, tt.args.scanID, true); (err != nil) != tt.wantErr { - t.Errorf("Service.StartScan() error = %v, wantErr %v", err, tt.wantErr) + var wg sync.WaitGroup + errCh := make(chan error) + wgDone := make(chan bool) + currentQuery := make(chan float64) + for _, serv := range s { + wg.Add(1) + serv.StartScan(tt.args.ctx, tt.args.scanID, true, errCh, &wg, currentQuery) + } + go func() { + defer func() { + close(currentQuery) + close(wgDone) + }() + wg.Wait() + }() + select { + case <-wgDone: + break + case err := <-errCh: + close(errCh) + if (err != nil) != tt.wantErr { + t.Errorf("Service.StartScan() error = %v, wantErr %v", err, tt.wantErr) + } } }) } } -func createParserSourceProvider(path string) (*parser.Parser, +func createParserSourceProvider(path string) ([]*parser.Parser, *provider.FileSystemSourceProvider, *resolver.Resolver) { mockParser, _ := parser.NewBuilder(). Add(&jsonParser.Parser{}). diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index acc3a99dc07..897d2bfc664 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -36,28 +36,32 @@ func (b *Builder) Add(p kindParser) *Builder { } // Build prepares parsers and associates a parser to its extension and returns it -func (b *Builder) Build(types []string) (*Parser, error) { +func (b *Builder) Build(types []string) ([]*Parser, error) { + parserSlice := make([]*Parser, 0, len(b.parsers)) var suportedTypes []string - parsers := make(map[string]kindParser, len(b.parsers)) - extensions := make(model.Extensions, len(b.parsers)) for _, parser := range b.parsers { - suportedTypes = append(suportedTypes, parser.SupportedTypes()...) + var parsers kindParser + extensions := make(model.Extensions, len(b.parsers)) + platforms := parser.SupportedTypes() + suportedTypes = append(suportedTypes, platforms...) if _, _, ok := contains(types, parser.SupportedTypes()); ok { + parsers = parser for _, ext := range parser.SupportedExtensions() { - parsers[ext] = parser extensions[ext] = struct{}{} } + parserSlice = append(parserSlice, &Parser{ + parsers: parsers, + extensions: extensions, + Platform: platforms, + }) } } if err := validateArguments(types, suportedTypes); err != nil { - return &Parser{}, err + return []*Parser{}, err } - return &Parser{ - parsers: parsers, - extensions: extensions, - }, nil + return parserSlice, nil } // ErrNotSupportedFile represents an error when a file is not supported by KICS @@ -65,8 +69,9 @@ var ErrNotSupportedFile = errors.New("unsupported file to parse") // Parser is a struct that associates a parser to its supported extensions type Parser struct { - parsers map[string]kindParser + parsers kindParser extensions model.Extensions + Platform []string } // Parse executes a parser on the fileContent and returns the file content as a Document, the file kind and @@ -76,17 +81,17 @@ func (c *Parser) Parse(filePath string, fileContent []byte) ([]model.Document, m if ext == "" { ext = filepath.Base(filePath) } - if p, ok := c.parsers[ext]; ok { - resolved, err := p.Resolve(fileContent, filePath) + if _, ok := c.extensions[ext]; ok { + resolved, err := c.parsers.Resolve(fileContent, filePath) if err != nil { return nil, "", err } - obj, err := p.Parse(filePath, *resolved) + obj, err := c.parsers.Parse(filePath, *resolved) if err != nil { return nil, "", err } - return obj, p.GetKind(), nil + return obj, c.parsers.GetKind(), nil } return nil, "", ErrNotSupportedFile diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go index f6bbacc6efd..c611a9b649b 100644 --- a/pkg/parser/parser_test.go +++ b/pkg/parser/parser_test.go @@ -16,36 +16,51 @@ import ( func TestParser_Parse(t *testing.T) { p := initilizeBuilder() - docs, kind, err := p.Parse("test.json", []byte(` + for _, parser := range p { + if _, ok := parser.extensions[".json"]; !ok { + continue + } + docs, kind, err := parser.Parse("test.json", []byte(` { "martin": { - "name": "Martin D'vloper" + "name": "CxBraga" } } `)) - require.NoError(t, err) - require.Len(t, docs, 1) - require.Contains(t, docs[0], "martin") - require.Equal(t, model.KindJSON, kind) + require.NoError(t, err) + require.Len(t, docs, 1) + require.Contains(t, docs[0], "martin") + require.Equal(t, model.KindJSON, kind) + } - docs, kind, err = p.Parse("test.yaml", []byte(` + for _, parser := range p { + if _, ok := parser.extensions[".yaml"]; !ok { + continue + } + docs, kind, err := parser.Parse("test.yaml", []byte(` martin: - name: Martin D'vloper + name: CxBraga `)) - require.NoError(t, err) - require.Len(t, docs, 1) - require.Contains(t, docs[0], "martin") - require.Equal(t, model.KindYAML, kind) + require.NoError(t, err) + require.Len(t, docs, 1) + require.Contains(t, docs[0], "martin") + require.Equal(t, model.KindYAML, kind) + } - docs, kind, err = p.Parse("Dockerfile", []byte(` - FROM foo - COPY . / - RUN echo hello + for _, parser := range p { + if _, ok := parser.extensions[".dockerfile"]; !ok { + continue + } + docs, kind, err := parser.Parse("Dockerfile", []byte(` +FROM foo +COPY . / +RUN echo hello `)) - require.NoError(t, err) - require.Len(t, docs, 1) - require.Equal(t, model.KindDOCKER, kind) + require.NoError(t, err) + require.Len(t, docs, 1) + require.Equal(t, model.KindDOCKER, kind) + } } // TestParser_Empty tests the functions [Parse()] and all the methods called by them (tests an empty parser) @@ -55,18 +70,26 @@ func TestParser_Empty(t *testing.T) { if err != nil { t.Errorf("Error building parser: %s", err) } - doc, kind, err := p.Parse("test.json", nil) - require.Nil(t, doc) - require.Equal(t, model.FileKind(""), kind) - require.Error(t, err) - require.Equal(t, ErrNotSupportedFile, err) + for _, parser := range p { + doc, kind, err := parser.Parse("test.json", nil) + require.Nil(t, doc) + require.Equal(t, model.FileKind(""), kind) + require.Error(t, err) + require.Equal(t, ErrNotSupportedFile, err) + } } // TestParser_SupportedExtensions tests the functions [SupportedExtensions()] and all the methods called by them func TestParser_SupportedExtensions(t *testing.T) { p := initilizeBuilder() + extensions := make(map[string]struct{}) - extensions := p.SupportedExtensions() + for _, parser := range p { + got := parser.SupportedExtensions() + for key := range got { + extensions[key] = struct{}{} + } + } require.NotNil(t, extensions) require.Contains(t, extensions, ".json") require.Contains(t, extensions, ".tf") @@ -75,7 +98,7 @@ func TestParser_SupportedExtensions(t *testing.T) { require.Contains(t, extensions, "Dockerfile") } -func initilizeBuilder() *Parser { +func initilizeBuilder() []*Parser { bd, _ := NewBuilder(). Add(&jsonParser.Parser{}). Add(&yamlParser.Parser{}). diff --git a/pkg/parser/terraform/converter/default.go b/pkg/parser/terraform/converter/default.go index c11d6b848fa..d2f15a794b6 100644 --- a/pkg/parser/terraform/converter/default.go +++ b/pkg/parser/terraform/converter/default.go @@ -266,7 +266,7 @@ func (c *converter) convertTemplateFor(expr *hclsyntax.ForExpr) (string, error) func (c *converter) wrapExpr(expr hclsyntax.Expression) (string, error) { expression := c.rangeSource(expr.Range()) if strings.HasPrefix(expression, "var.") { - log.Warn().Msgf("Variable ${%s} value not found", expression) + log.Trace().Msgf("Variable ${%s} value not found", expression) } return "${" + expression + "}", nil } diff --git a/pkg/scanner/scanner.go b/pkg/scanner/scanner.go new file mode 100644 index 00000000000..88d19348487 --- /dev/null +++ b/pkg/scanner/scanner.go @@ -0,0 +1,69 @@ +package scanner + +import ( + "context" + "fmt" + "io" + "sync" + + consoleHelpers "github.com/Checkmarx/kics/internal/console/helpers" + "github.com/Checkmarx/kics/internal/metrics" + "github.com/Checkmarx/kics/pkg/kics" +) + +type serviceSlice []*kics.Service + +// StartScan will run concurrent scans by parser +func StartScan(ctx context.Context, scanID string, noProgress bool, services serviceSlice) error { + defer metrics.Metric.Stop() + metrics.Metric.Start("start_scan") + var wg sync.WaitGroup + wgDone := make(chan bool) + errCh := make(chan error) + currentQuery := make(chan float64, 1) + var wgProg sync.WaitGroup + total := services.GetQueriesLength() + if total != 0 { + startProgressBar(noProgress, total, &wgProg, currentQuery) + } + for _, service := range services { + wg.Add(1) + go service.StartScan(ctx, scanID, noProgress, errCh, &wg, currentQuery) + } + + go func() { + defer func() { + close(currentQuery) + close(wgDone) + fmt.Println("\r") + }() + wg.Wait() + wgProg.Wait() + }() + + select { + case <-wgDone: + break + case err := <-errCh: + close(errCh) + return err + } + return nil +} + +func (s serviceSlice) GetQueriesLength() int { + count := 0 + for _, service := range s { + count += service.Inspector.LenQueriesByPlat(service.Parser.Platform) + } + return count +} + +func startProgressBar(hideProgress bool, total int, wg *sync.WaitGroup, progressChannel chan float64) { + wg.Add(1) + progressBar := consoleHelpers.NewProgressBar("Executing queries: ", 10, float64(total), progressChannel) + if hideProgress { + progressBar.Writer = io.Discard + } + go progressBar.Start(wg) +} diff --git a/test/main_test.go b/test/main_test.go index 96fe820b433..019bfdf9481 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -126,23 +126,24 @@ func getFilesMetadatasWithContent(t testing.TB, filePath string, content []byte) combinedParser := getCombinedParser() files := make(model.FileMetadatas, 0) - parsedDocuments, kind, err := combinedParser.Parse(filePath, content) - require.NoError(t, err) - for _, document := range parsedDocuments { - files = append(files, model.FileMetadata{ - ID: uuid.NewString(), - ScanID: scanID, - Document: document, - OriginalData: string(content), - Kind: kind, - FileName: filePath, - }) + for _, parser := range combinedParser { + parsedDocuments, kind, err := parser.Parse(filePath, content) + for _, document := range parsedDocuments { + require.NoError(t, err) + files = append(files, model.FileMetadata{ + ID: uuid.NewString(), + ScanID: scanID, + Document: document, + OriginalData: string(content), + Kind: kind, + FileName: filePath, + }) + } } - return files } -func getCombinedParser() *parser.Parser { +func getCombinedParser() []*parser.Parser { bd, _ := parser.NewBuilder(). Add(&jsonParser.Parser{}). Add(&yamlParser.Parser{}). diff --git a/test/queries_content_test.go b/test/queries_content_test.go index a37cfba1e00..3b5d52a538e 100644 --- a/test/queries_content_test.go +++ b/test/queries_content_test.go @@ -248,7 +248,15 @@ func testQueryHasGoodReturnParams(t *testing.T, entry queryEntry) { inspector.EnableCoverageReport() - _, err = inspector.Inspect(ctx, scanID, getFileMetadatas(t, entry.PositiveFiles(t)), true, []string{BaseTestsScanPath}) + platforms := []string{"Ansible", "CloudFormation", "Kubernetes", "OpenAPI", "Terraform", "Dockerfile"} + currentQuery := make(chan float64) + _, err = inspector.Inspect(ctx, scanID, getFileMetadatas( + t, + entry.PositiveFiles(t)), + true, []string{BaseTestsScanPath}, + platforms, + currentQuery, + ) require.Nil(t, err) report := inspector.GetCoverageReport() diff --git a/test/queries_test.go b/test/queries_test.go index 7688e565339..88f0e2923eb 100644 --- a/test/queries_test.go +++ b/test/queries_test.go @@ -132,7 +132,17 @@ func testQuery(tb testing.TB, entry queryEntry, filesPath []string, expectedVuln require.Nil(tb, err) require.NotNil(tb, inspector) - vulnerabilities, err := inspector.Inspect(ctx, scanID, getFileMetadatas(tb, filesPath), true, []string{BaseTestsScanPath}) + platforms := []string{"Ansible", "CloudFormation", "Kubernetes", "OpenAPI", "Terraform", "Dockerfile"} + currentQuery := make(chan float64) + + vulnerabilities, err := inspector.Inspect( + ctx, + scanID, + getFileMetadatas(tb, filesPath), + true, []string{BaseTestsScanPath}, + platforms, + currentQuery, + ) require.Nil(tb, err) requireEqualVulnerabilities(tb, expectedVulnerabilities, vulnerabilities, entry) } diff --git a/test/similarity_id_test.go b/test/similarity_id_test.go index 2a122a163cb..46114a4e2e4 100644 --- a/test/similarity_id_test.go +++ b/test/similarity_id_test.go @@ -297,6 +297,8 @@ func createInspectorAndGetVulnerabilities(ctx context.Context, t testing.TB, require.Nil(t, err) require.NotNil(t, inspector) + currentQuery := make(chan float64) + vulnerabilities, err := inspector.Inspect( ctx, scanID, @@ -307,6 +309,8 @@ func createInspectorAndGetVulnerabilities(ctx context.Context, t testing.TB, ), true, []string{BaseTestsScanPath}, + []string{"Ansible", "CloudFormation", "Kubernetes", "OpenAPI", "Terraform", "Dockerfile"}, + currentQuery, ) require.Nil(t, err) return vulnerabilities