Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cli): support manifest bigger than 1k packages #285

Merged
merged 3 commits into from
Jan 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ func (c *Client) URL() string {
return c.baseURL.String()
}

// ValidAuth verifies that the client has valid authentication
func (c *Client) ValidAuth() bool {
return c.auth.token != ""
}

// newID generates a new client id, this id is useful for logging purposes
// when there are more than one client running on the same machine
func newID() string {
Expand Down
5 changes: 4 additions & 1 deletion cli/cmd/honeyvent.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ const (
//
// ```go
// cli.Event.Feature = featPollCtrScan
// cli.Event.FeatureData = map[string]interface{"key", "value"}
// cli.Event.AddFeatureField("key", "value")
// cli.SendHoneyvent()
// ```
//
Expand All @@ -74,6 +74,9 @@ const (

// Generate package manifest feature
featGenPkgManifest = "gen_pkg_manifest"

// Split package manifest feature
featSplitPkgManifest = "split_pkg_manifest"
)

// Honeyvent defines what a Honeycomb event looks like for the Lacework CLI
Expand Down
125 changes: 125 additions & 0 deletions cli/cmd/package_manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,128 @@ func removeEpochFromPkgVersion(pkgVer string) string {

return pkgVer
}

// split the provided package_manifest into chucks, if the manifest
// is smaller than the provided chunk size, it will return the manifest
// as an array without modifications
func splitPackageManifest(manifest *api.PackageManifest, chunks int) []*api.PackageManifest {
if len(manifest.OsPkgInfoList) <= chunks {
return []*api.PackageManifest{manifest}
}

var batches []*api.PackageManifest
for i := 0; i < len(manifest.OsPkgInfoList); i += chunks {
batch := manifest.OsPkgInfoList[i:min(i+chunks, len(manifest.OsPkgInfoList))]
cli.Log.Infow("manifest batch", "total_packages", len(batch))
batches = append(batches, &api.PackageManifest{OsPkgInfoList: batch})
}
return batches
}

func min(a, b int) int {
if a <= b {
return a
}
return b
}

// fan-out a number of package manifests into multiple requests all at once
func fanOutHostScans(manifests ...*api.PackageManifest) (api.HostVulnScanPkgManifestResponse, error) {
var (
resCh = make(chan api.HostVulnScanPkgManifestResponse)
errCh = make(chan error)
workers = len(manifests)
fanInRes = api.HostVulnScanPkgManifestResponse{}
)

// disallow more than 10 workers which are 10 calls all at once,
// the API has a rate-limit of 10 calls per hour, per access key
if workers > 10 {
return fanInRes, errors.New("limit of packages exceeded")
}

var (
err error
start = time.Now()
)
defer func() {
cli.Event.DurationMs = time.Since(start).Milliseconds()
// avoid duplicating events
if err == nil {
cli.SendHoneyvent()
}
}()

// ensure that the api client has a valid token
// before creating workers
if !cli.LwApi.ValidAuth() {
_, err = cli.LwApi.GenerateToken()
if err != nil {
return fanInRes, err
}
}

// for every manifest, create a new worker, that is, spawn
// a new goroutine that will send the manifest to scan
for n, m := range manifests {
if m == nil {
workers--
continue
}
cli.Log.Infow("spawn worker", "number", n+1)
go cli.triggerHostVulnScan(m, resCh, errCh)
}

cli.Event.AddFeatureField("workers", workers)

// lock the main process and read both, the error and response
// channels, if we receive at least one error, we will stop
// processing and bubble up the error to the caller
for processed := 0; processed < workers; processed++ {
select {
case err = <-errCh:
// end processing as soon as we receive the first error
return fanInRes, err
case res := <-resCh:
// processing scan
cli.Log.Infow("processing worker response", "n", processed+1)
cli.Event.AddFeatureField(fmt.Sprintf("worker%d_total_vulns", processed), len(res.Vulns))
mergeHostVulnScanPkgManifestResponses(&fanInRes, &res)
}
}

return fanInRes, nil
}

func mergeHostVulnScanPkgManifestResponses(to, from *api.HostVulnScanPkgManifestResponse) {
// append vulnerabilities from -> to
to.Vulns = append(to.Vulns, from.Vulns...)

// requests should always return an ok state
to.Ok = from.Ok

// store the message from the response only if it is NOT empty
// and it is different from the previous response (to)
if to.Message == "" {
to.Message = from.Message
return
}

// concatenate messages "to,from" response only if they
// are NOT empty and they are different from each other
if from.Message != "" && from.Message != to.Message {
to.Message = fmt.Sprintf("%s,%s", to.Message, from.Message)
}
}

func (c *cliState) triggerHostVulnScan(manifest *api.PackageManifest,
resCh chan<- api.HostVulnScanPkgManifestResponse,
errCh chan<- error,
) {
response, err := c.LwApi.Vulnerabilities.Host.Scan(manifest)
if err != nil {
errCh <- err
return
}
resCh <- response
}
131 changes: 131 additions & 0 deletions cli/cmd/package_manifest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package cmd

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -91,3 +92,133 @@ func TestRemoveEpochFromPkgVersion(t *testing.T) {
"version",
removeEpochFromPkgVersion("epoch:version"))
}

func TestSplitPackageManifest(t *testing.T) {
cases := []struct {
chunks int
size int
expectedSize int
}{
{expectedSize: 100,
size: 500,
chunks: 5},
{expectedSize: 45,
size: 45000,
chunks: 1000},
{expectedSize: 50,
size: 100,
chunks: 2},
{expectedSize: 2,
size: 1001,
chunks: 1000},
{expectedSize: 28,
size: 55000,
chunks: 2000},
{expectedSize: 1,
size: 123,
chunks: 1000},
}
for i, kase := range cases {
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
manifest := &api.PackageManifest{
OsPkgInfoList: make([]api.OsPkgInfo, kase.size),
}
subject := splitPackageManifest(manifest, kase.chunks)
assert.Equal(t, kase.expectedSize, len(subject))
})
}
}

func TestFanOutHostScans(t *testing.T) {
// mock the api client
client, err := api.NewClient("test", api.WithToken("mock"))
assert.Nil(t, err)
client.Vulnerabilities = api.NewVulnerabilityService(client)
cli.LwApi = client
defer func() {
cli.LwApi = nil
}()

subject, err := fanOutHostScans()
assert.Nil(t, err)
assert.Equal(t, api.HostVulnScanPkgManifestResponse{}, subject)

subject, err = fanOutHostScans(nil)
assert.Nil(t, err)
assert.Equal(t, api.HostVulnScanPkgManifestResponse{}, subject)

// more than 10 morkers should return an error
multiManifests := make([]*api.PackageManifest, 11)
subject, err = fanOutHostScans(multiManifests...)
if assert.NotNil(t, err) {
assert.Contains(t, err.Error(),
"limit of packages exceeded",
)
}
assert.Equal(t, api.HostVulnScanPkgManifestResponse{}, subject)

subject, err = fanOutHostScans(&api.PackageManifest{})
if assert.NotNil(t, err) {
assert.Contains(t, err.Error(),
"[403] Forbidden", // intentional error since we are mocking the api token
)
}
assert.Equal(t, api.HostVulnScanPkgManifestResponse{}, subject)
}

func TestMergeHostVulnScanPkgManifestResponses(t *testing.T) {
cases := []struct {
expected api.HostVulnScanPkgManifestResponse
from api.HostVulnScanPkgManifestResponse
to api.HostVulnScanPkgManifestResponse
}{
// empty responses
{expected: api.HostVulnScanPkgManifestResponse{},
from: api.HostVulnScanPkgManifestResponse{},
to: api.HostVulnScanPkgManifestResponse{}},
// responses should return an Ok status
{expected: api.HostVulnScanPkgManifestResponse{
Ok: true},
from: api.HostVulnScanPkgManifestResponse{
Ok: true},
to: api.HostVulnScanPkgManifestResponse{
Ok: false}},
// messages should change only if the previous one is empty or different
{expected: api.HostVulnScanPkgManifestResponse{
Message: "SUCCESS"},
from: api.HostVulnScanPkgManifestResponse{
Message: "SUCCESS"},
to: api.HostVulnScanPkgManifestResponse{
Message: ""}},
{expected: api.HostVulnScanPkgManifestResponse{
Message: "YES"},
from: api.HostVulnScanPkgManifestResponse{
Message: ""},
to: api.HostVulnScanPkgManifestResponse{
Message: "YES"}},
{expected: api.HostVulnScanPkgManifestResponse{
Message: "OLD,NEW"},
from: api.HostVulnScanPkgManifestResponse{
Message: "NEW"},
to: api.HostVulnScanPkgManifestResponse{
Message: "OLD"}},
// merge two responses into one single response 1 + 1 = 2
{
expected: api.HostVulnScanPkgManifestResponse{
Vulns: []api.HostScanPackageVulnDetails{
api.HostScanPackageVulnDetails{}, api.HostScanPackageVulnDetails{},
},
},
from: api.HostVulnScanPkgManifestResponse{
Vulns: []api.HostScanPackageVulnDetails{api.HostScanPackageVulnDetails{}}},
to: api.HostVulnScanPkgManifestResponse{
Vulns: []api.HostScanPackageVulnDetails{api.HostScanPackageVulnDetails{}}},
},
}
for i, kase := range cases {
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
mergeHostVulnScanPkgManifestResponses(&kase.to, &kase.from)
assert.Equal(t, kase.expected, kase.to)
})
}
}
32 changes: 24 additions & 8 deletions cli/cmd/vuln_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ import (
)

var (
// the maximum number of packages per scan request
manifestPkgsCap = 1000

// the package manifest file
pkgManifestFile string

Expand Down Expand Up @@ -100,19 +103,19 @@ To generate a package-manifest from the local host and scan it automatically:

if len(args) != 0 && args[0] != "" {
pkgManifestBytes = []byte(args[0])
cli.Log.Infow("package manifest loaded from arguments", "raw", args[0])
cli.Log.Debugw("package manifest loaded from arguments", "raw", args[0])
} else if pkgManifestFile != "" {
pkgManifestBytes, err = ioutil.ReadFile(pkgManifestFile)
if err != nil {
return errors.Wrap(err, "unable to read file")
}
cli.Log.Infow("package manifest loaded from file", "raw", string(pkgManifestBytes))
cli.Log.Debugw("package manifest loaded from file", "raw", string(pkgManifestBytes))
} else if pkgManifestLocal {
pkgManifest, err = cli.GeneratePackageManifest()
if err != nil {
return errors.Wrap(err, "unable to generate package manifest")
}
cli.Log.Infow("package manifest generated from localhost", "raw", pkgManifest)
cli.Log.Debugw("package manifest generated from localhost", "raw", pkgManifest)
} else {
// avoid asking for a confirmation before launching the editor
var content string
Expand All @@ -125,7 +128,7 @@ To generate a package-manifest from the local host and scan it automatically:
return errors.Wrap(err, "unable to load package manifest from editor")
}
pkgManifestBytes = []byte(content)
cli.Log.Infow("package manifest loaded via editor", "raw", content)
cli.Log.Debugw("package manifest loaded via editor", "raw", content)
}

if len(pkgManifestBytes) != 0 {
Expand All @@ -135,10 +138,23 @@ To generate a package-manifest from the local host and scan it automatically:
}
}

// TODO @afiune check if the package manifest has more than
// 1k packages, if so, make multiple API requests

response, err := cli.LwApi.Vulnerabilities.Host.Scan(pkgManifest)
totalPkgs := len(pkgManifest.OsPkgInfoList)
cli.StartProgress(" Scanning packages...")
cli.Log.Infow("manifest", "total_packages", totalPkgs)
var response api.HostVulnScanPkgManifestResponse
// check if the package manifest has more than the maximum
// number of packages, if so, make multiple API requests
if totalPkgs >= manifestPkgsCap {
cli.Log.Infow("manifest over the limit, splitting up")
cli.Event.Feature = featSplitPkgManifest
cli.Event.AddFeatureField("total_packages", totalPkgs)
response, err = fanOutHostScans(
splitPackageManifest(pkgManifest, manifestPkgsCap)...,
)
} else {
response, err = cli.LwApi.Vulnerabilities.Host.Scan(pkgManifest)
}
cli.StopProgress()
if err != nil {
return errors.Wrap(err, "unable to request an on-demand host vulnerability scan")
}
Expand Down