From eebddb9325ede76ffa1853d00508da54cb5b9678 Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Thu, 7 Jan 2021 10:27:03 +0100 Subject: [PATCH] feat(cli): support manifest bigger than 1k packages **User Story** As a user of the Lacework CLI, I would like to be able to submit scans of package manifests bigger than 1,000 packages, So I don't have to implement a splitting mechanism and run multiple CLI commands for a single manifest. **Implementation Details** The CLI will now check if the package manifest has more than the maximum number of packages, if so, it will split the package manifest into multiple chunks and trigger multiple API requests. **NOTE:** We disallow more than 10 parallel requests (workers), which are more than 10,000 packages on a single manifest/system. Closes https://github.com/lacework/go-sdk/issues/237 Signed-off-by: Salim Afiune Maya --- cli/cmd/package_manifest.go | 79 ++++++++++++++++++++++++++++++++ cli/cmd/package_manifest_test.go | 74 ++++++++++++++++++++++++++++++ cli/cmd/vuln_host.go | 29 ++++++++---- 3 files changed, 174 insertions(+), 8 deletions(-) diff --git a/cli/cmd/package_manifest.go b/cli/cmd/package_manifest.go index abd0610ad..02158a4a9 100644 --- a/cli/cmd/package_manifest.go +++ b/cli/cmd/package_manifest.go @@ -349,3 +349,82 @@ func removeEpochFromPkgVersion(pkgVer string) string { return pkgVer } + +// split the provided package_manifest into chucks, if the manifest +// is smaller than the provided chunk size, it will return the manifest +// as an array without modifications +func splitPackageManifest(manifest *api.PackageManifest, chunks int) []*api.PackageManifest { + if len(manifest.OsPkgInfoList) <= chunks { + return []*api.PackageManifest{manifest} + } + + var batches []*api.PackageManifest + for i := 0; i < len(manifest.OsPkgInfoList); i += chunks { + batch := manifest.OsPkgInfoList[i:min(i+chunks, len(manifest.OsPkgInfoList))] + cli.Log.Infow("manifest batch", "total_packages", len(batch)) + batches = append(batches, &api.PackageManifest{OsPkgInfoList: batch}) + } + return batches +} + +func min(a, b int) int { + if a <= b { + return a + } + return b +} + +// fan-out a number of package manifests into multiple requests all at once +func fanOutHostScans(manifests ...*api.PackageManifest) (api.HostVulnScanPkgManifestResponse, error) { + var ( + resCh = make(chan api.HostVulnScanPkgManifestResponse) + errCh = make(chan error) + workers = len(manifests) + fanInRes = api.HostVulnScanPkgManifestResponse{} + ) + + // disallow more than 10 workers which are 10 calls all at once, + // the API has a rate-limit of 10 calls per hour, per access key + if workers > 10 { + return fanInRes, errors.New("limit of packages exceeded") + } + + // for every manifest, create a new worker, that is, spawn + // a new goroutine that will send the manifest to scan + for n, m := range manifests { + if m == nil { + workers-- + continue + } + cli.Log.Infow("spawn worker", "number", n+1) + go func(manifest *api.PackageManifest, c *cliState) { + response, err := c.LwApi.Vulnerabilities.Host.Scan(manifest) + if err != nil { + errCh <- err + return + } + resCh <- response + }(m, cli) + } + + // lock the main process and read both, the error and response + // channels, if we receive at least one error, we will stop + // processing and bubble up the error to the caller + for processed := 0; processed < workers; processed++ { + select { + case err := <-errCh: + // end processing as soon as we receive the first error + return fanInRes, err + case res := <-resCh: + // processing scan + cli.Log.Infow("processing worker response", "n", processed+1) + fanInRes.Vulns = append(fanInRes.Vulns, res.Vulns...) + if res.Message != "" && res.Message != fanInRes.Message { + fanInRes.Message = res.Message + } + fanInRes.Ok = res.Ok + } + } + + return fanInRes, nil +} diff --git a/cli/cmd/package_manifest_test.go b/cli/cmd/package_manifest_test.go index e7f31e74c..a949e0f0a 100644 --- a/cli/cmd/package_manifest_test.go +++ b/cli/cmd/package_manifest_test.go @@ -19,6 +19,7 @@ package cmd import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -91,3 +92,76 @@ func TestRemoveEpochFromPkgVersion(t *testing.T) { "version", removeEpochFromPkgVersion("epoch:version")) } + +func TestSplitPackageManifest(t *testing.T) { + cases := []struct { + chunks int + size int + expectedSize int + }{ + {expectedSize: 100, + size: 500, + chunks: 5}, + {expectedSize: 45, + size: 45000, + chunks: 1000}, + {expectedSize: 50, + size: 100, + chunks: 2}, + {expectedSize: 2, + size: 1001, + chunks: 1000}, + {expectedSize: 28, + size: 55000, + chunks: 2000}, + {expectedSize: 1, + size: 123, + chunks: 1000}, + } + for i, kase := range cases { + t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { + manifest := &api.PackageManifest{ + OsPkgInfoList: make([]api.OsPkgInfo, kase.size), + } + subject := splitPackageManifest(manifest, kase.chunks) + assert.Equal(t, kase.expectedSize, len(subject)) + }) + } +} + +func TestFanOutHostScans(t *testing.T) { + subject, err := fanOutHostScans() + assert.Nil(t, err) + assert.Equal(t, api.HostVulnScanPkgManifestResponse{}, subject) + + subject, err = fanOutHostScans(nil) + assert.Nil(t, err) + assert.Equal(t, api.HostVulnScanPkgManifestResponse{}, subject) + + // more than 10 morkers should return an error + multiManifests := make([]*api.PackageManifest, 11) + subject, err = fanOutHostScans(multiManifests...) + if assert.NotNil(t, err) { + assert.Contains(t, err.Error(), + "limit of packages exceeded", + ) + } + assert.Equal(t, api.HostVulnScanPkgManifestResponse{}, subject) + + // mock the api client + client, err := api.NewClient("test") + assert.Nil(t, err) + client.Vulnerabilities = api.NewVulnerabilityService(client) + cli.LwApi = client + defer func() { + cli.LwApi = nil + }() + + subject, err = fanOutHostScans(&api.PackageManifest{}) + if assert.NotNil(t, err) { + assert.Contains(t, err.Error(), + "unable to generate access token: auth keys missing", + ) + } + assert.Equal(t, api.HostVulnScanPkgManifestResponse{}, subject) +} diff --git a/cli/cmd/vuln_host.go b/cli/cmd/vuln_host.go index 7611b88f8..89c19c017 100644 --- a/cli/cmd/vuln_host.go +++ b/cli/cmd/vuln_host.go @@ -35,6 +35,9 @@ import ( ) var ( + // the maximum number of packages per scan request + manifestPkgsCap = 1000 + // the package manifest file pkgManifestFile string @@ -100,19 +103,19 @@ To generate a package-manifest from the local host and scan it automatically: if len(args) != 0 && args[0] != "" { pkgManifestBytes = []byte(args[0]) - cli.Log.Infow("package manifest loaded from arguments", "raw", args[0]) + cli.Log.Debugw("package manifest loaded from arguments", "raw", args[0]) } else if pkgManifestFile != "" { pkgManifestBytes, err = ioutil.ReadFile(pkgManifestFile) if err != nil { return errors.Wrap(err, "unable to read file") } - cli.Log.Infow("package manifest loaded from file", "raw", string(pkgManifestBytes)) + cli.Log.Debugw("package manifest loaded from file", "raw", string(pkgManifestBytes)) } else if pkgManifestLocal { pkgManifest, err = cli.GeneratePackageManifest() if err != nil { return errors.Wrap(err, "unable to generate package manifest") } - cli.Log.Infow("package manifest generated from localhost", "raw", pkgManifest) + cli.Log.Debugw("package manifest generated from localhost", "raw", pkgManifest) } else { // avoid asking for a confirmation before launching the editor var content string @@ -125,7 +128,7 @@ To generate a package-manifest from the local host and scan it automatically: return errors.Wrap(err, "unable to load package manifest from editor") } pkgManifestBytes = []byte(content) - cli.Log.Infow("package manifest loaded via editor", "raw", content) + cli.Log.Debugw("package manifest loaded via editor", "raw", content) } if len(pkgManifestBytes) != 0 { @@ -135,10 +138,20 @@ To generate a package-manifest from the local host and scan it automatically: } } - // TODO @afiune check if the package manifest has more than - // 1k packages, if so, make multiple API requests - - response, err := cli.LwApi.Vulnerabilities.Host.Scan(pkgManifest) + cli.StartProgress(" Scanning packages...") + cli.Log.Infow("manifest", "total_packages", len(pkgManifest.OsPkgInfoList)) + var response api.HostVulnScanPkgManifestResponse + // check if the package manifest has more than the maximum + // number of packages, if so, make multiple API requests + if len(pkgManifest.OsPkgInfoList) >= manifestPkgsCap { + cli.Log.Infow("manifest over the limit, splitting up") + response, err = fanOutHostScans( + splitPackageManifest(pkgManifest, manifestPkgsCap)..., + ) + } else { + response, err = cli.LwApi.Vulnerabilities.Host.Scan(pkgManifest) + } + cli.StopProgress() if err != nil { return errors.Wrap(err, "unable to request an on-demand host vulnerability scan") }