Skip to content

Commit

Permalink
Fix rego policy verification
Browse files Browse the repository at this point in the history
This commit addresses multiple issues when applying a rego policy
against the payload of an attestation.

1. If `data.signature.deny` evaluated to `true`, the policy verification
would pass. This is obviously unexpected. The code now looks for
`data.signature.allow` instead, and expects it to be `true`.

2. If a query result returned an undefined results, the policy
verification would pass. The code now explicitly checks for this
condition and ensure that if `ResultSet.IsAllowed()` returns `false`,
the policy verification also fails.

3. Improve error messages to assist user in defining correct variable
name and type.

4. Add unit tests to validate behavior and prevent breaking changes in
the future.

Signed-off-by: Luiz Carvalho <[email protected]>
  • Loading branch information
lcarva committed Mar 28, 2022
1 parent 1a54eb5 commit 7969eaf
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 8 deletions.
23 changes: 15 additions & 8 deletions pkg/cosign/rego/rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ import (
"github.com/open-policy-agent/opa/rego"
)

// The query below should meet the following requirements:
// * Provides no Bindings. Do not use a query that sets a variable, e.g. x := data.signature.allow
// * Queries for a single value.
const QUERY = "data.signature.allow"

func ValidateJSON(jsonBody []byte, entrypoints []string) []error {
ctx := context.Background()

r := rego.New(
rego.Query("data.signature.deny"), // hardcoded, ? data.cosign.allow→
rego.Query(QUERY),
rego.Load(entrypoints, nil))

query, err := r.PrepareForEval(ctx)
Expand All @@ -48,21 +53,23 @@ func ValidateJSON(jsonBody []byte, entrypoints []string) []error {
return []error{err}
}

// Ensure the resultset contains a single result where the Expression contains a single value
// which is true and there are no Bindings.
if rs.Allowed() {
return nil
}

var errs []error
for _, result := range rs {
for _, expression := range result.Expressions {
for _, v := range expression.Value.([]interface{}) {
if s, ok := v.(string); ok {
errs = append(errs, fmt.Errorf(s))
} else {
errs = append(errs, fmt.Errorf("%s", v))
}
}
errs = append(errs, fmt.Errorf("expression value, %v, is not true", expression))
}
}

// When rs.Allowed() is not true and len(rs) is 0, the result is undefined. This is a policy
// check failure.
if len(errs) == 0 {
errs = append(errs, fmt.Errorf("result is undefined for query '%s'", QUERY))
}
return errs
}
82 changes: 82 additions & 0 deletions pkg/cosign/rego/rego_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package rego

import (
"fmt"
"os"
"path/filepath"
"testing"
)

const simpleJSONBody = `{
"_type": "https://in-toto.io/Statement/v0.1",
"predicateType": "https://slsa.dev/provenance/v0.2"
}`

func TestValidationJSON(t *testing.T) {
cases := []struct {
name string
jsonBody string
policy string
pass bool
errors []string
}{
{
name: "passing policy",
jsonBody: simpleJSONBody,
policy: `
package signature
allow {
input.predicateType == "https://slsa.dev/provenance/v0.2"
}
`,
pass: true,
},
{
name: "undefined result due to no matching rules",
jsonBody: simpleJSONBody,
policy: `
package signature
allow {
input.predicateType == "https://slsa.dev/provenance/v99.9"
}
`,
pass: false,
errors: []string{"result is undefined for query 'data.signature.allow'"},
},
{
name: "policy query evaluates to false",
jsonBody: simpleJSONBody,
policy: `
package signature
default allow = false
`,
pass: false,
errors: []string{"expression value, false, is not true"},
},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
policyFileName := filepath.Join(t.TempDir(), "policy.rego")
if err := os.WriteFile(policyFileName, []byte(tt.policy), 0666); err != nil {
t.Fatalf("Unable to create temporary policy file at %s", policyFileName)
}

if errs := ValidateJSON([]byte(tt.jsonBody), []string{policyFileName}); (errs == nil) != tt.pass {
t.Fatalf("Unexpected result: %v", errs)
} else if errs != nil {
if len(errs) != len(tt.errors) {
t.Fatalf("Expected %d errors, got %d errors: %v", len(tt.errors), len(errs), errs)
}
for i, err := range errs {
if fmt.Sprintf("%s", err) != tt.errors[i] {
t.Errorf("Expected error %q, got %q", tt.errors[i], err)
}
}
}
})
}
}

0 comments on commit 7969eaf

Please sign in to comment.