Skip to content

Commit

Permalink
cmd & format: Adding rego-v1 mode to opa fmt
Browse files Browse the repository at this point in the history
Fixes: open-policy-agent#6297
Signed-off-by: Johan Fylling <[email protected]>
  • Loading branch information
johanfylling committed Nov 9, 2023
1 parent 30a244e commit 7370ed1
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 15 deletions.
10 changes: 5 additions & 5 deletions ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"github.com/open-policy-agent/opa/ast/location"
)

var regoV1CompatibleRef = Ref{VarTerm("rego"), StringTerm("v1")}
var RegoV1CompatibleRef = Ref{VarTerm("rego"), StringTerm("v1")}

// Note: This state is kept isolated from the parser so that we
// can do efficient shallow copies of these values when doing a
Expand Down Expand Up @@ -2528,7 +2528,7 @@ func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]toke
}

if p.s.s.RegoV1Compatible() {
p.errorf(imp.Path.Location, "the `%s` import implies `future.keywords`, these are therefore mutually exclusive", regoV1CompatibleRef)
p.errorf(imp.Path.Location, "the `%s` import implies `future.keywords`, these are therefore mutually exclusive", RegoV1CompatibleRef)
return
}

Expand Down Expand Up @@ -2563,8 +2563,8 @@ func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]toke
func (p *Parser) regoV1Import(imp *Import) {
path := imp.Path.Value.(Ref)

if len(path) == 1 || !path[1].Equal(regoV1CompatibleRef[1]) || len(path) > 2 {
p.errorf(imp.Path.Location, "invalid import, must be `%s`", regoV1CompatibleRef)
if len(path) == 1 || !path[1].Equal(RegoV1CompatibleRef[1]) || len(path) > 2 {
p.errorf(imp.Path.Location, "invalid import, must be `%s`", RegoV1CompatibleRef)
return
}

Expand All @@ -2581,7 +2581,7 @@ func (p *Parser) regoV1Import(imp *Import) {

if p.s.s.HasKeyword(futureKeywords) && !p.s.s.RegoV1Compatible() {
// We have imported future keywords, but they didn't come from another `rego.v1` import.
p.errorf(imp.Path.Location, "the `%s` import implies `future.keywords`, these are therefore mutually exclusive", regoV1CompatibleRef)
p.errorf(imp.Path.Location, "the `%s` import implies `future.keywords`, these are therefore mutually exclusive", RegoV1CompatibleRef)
return
}

Expand Down
2 changes: 1 addition & 1 deletion ast/parser_ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ func parseModule(filename string, stmts []Statement, comments []*Comment) (*Modu
switch stmt := stmt.(type) {
case *Import:
mod.Imports = append(mod.Imports, stmt)
if Compare(stmt.Path.Value, regoV1CompatibleRef) == 0 {
if Compare(stmt.Path.Value, RegoV1CompatibleRef) == 0 {
mod.regoV1Compatible = true
}
case *Rule:
Expand Down
10 changes: 6 additions & 4 deletions cmd/fmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type fmtCommandParams struct {
list bool
diff bool
fail bool
regoV1 bool
}

var fmtParams = fmtCommandParams{}
Expand Down Expand Up @@ -57,7 +58,7 @@ code if a file would be reformatted.`,
func opaFmt(args []string) int {

if len(args) == 0 {
if err := formatStdin(os.Stdin, os.Stdout); err != nil {
if err := formatStdin(&fmtParams, os.Stdin, os.Stdout); err != nil {
fmt.Fprintln(os.Stderr, err)
return 1
}
Expand Down Expand Up @@ -108,7 +109,7 @@ func formatFile(params *fmtCommandParams, out io.Writer, filename string, info o
return newError("failed to open file: %v", err)
}

formatted, err := format.Source(filename, contents)
formatted, err := format.SourceWithOpts(filename, contents, format.Opts{RegoV1: params.regoV1})
if err != nil {
return newError("failed to parse Rego source file: %v", err)
}
Expand Down Expand Up @@ -166,14 +167,14 @@ func formatFile(params *fmtCommandParams, out io.Writer, filename string, info o
return nil
}

func formatStdin(r io.Reader, w io.Writer) error {
func formatStdin(params *fmtCommandParams, r io.Reader, w io.Writer) error {

contents, err := io.ReadAll(r)
if err != nil {
return err
}

formatted, err := format.Source("stdin", contents)
formatted, err := format.SourceWithOpts("stdin", contents, format.Opts{RegoV1: params.regoV1})
if err != nil {
return err
}
Expand Down Expand Up @@ -233,5 +234,6 @@ func init() {
formatCommand.Flags().BoolVarP(&fmtParams.list, "list", "l", false, "list all files who would change when formatted")
formatCommand.Flags().BoolVarP(&fmtParams.diff, "diff", "d", false, "only display a diff of the changes")
formatCommand.Flags().BoolVar(&fmtParams.fail, "fail", false, "non zero exit code on reformat")
formatCommand.Flags().BoolVar(&fmtParams.regoV1, "rego-v1", false, "format as Rego v1")
RootCommand.AddCommand(formatCommand)
}
49 changes: 44 additions & 5 deletions format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Opts struct {
// of partial evaluation, arguments maybe have been shuffled around, but still
// carry along their original source locations.
IgnoreLocations bool

RegoV1 bool
}

// defaultLocationFile is the file name used in `Ast()` for terms
Expand All @@ -36,12 +38,16 @@ const defaultLocationFile = "__format_default__"
// Rego module. If they don't, Source will return an error resulting from the attempt
// to parse the bytes.
func Source(filename string, src []byte) ([]byte, error) {
return SourceWithOpts(filename, src, Opts{})
}

func SourceWithOpts(filename string, src []byte, opts Opts) ([]byte, error) {
module, err := ast.ParseModule(filename, string(src))
if err != nil {
return nil, err
}

formatted, err := Ast(module)
formatted, err := AstWithOpts(module, opts)
if err != nil {
return nil, fmt.Errorf("%s: %v", filename, err)
}
Expand Down Expand Up @@ -81,6 +87,8 @@ type fmtOpts struct {
// for ref heads -- if they do, we'll print all of them in a different way
// than if they don't.
refHeads bool

regoV1 bool
}

func AstWithOpts(x interface{}, opts Opts) ([]byte, error) {
Expand All @@ -99,6 +107,8 @@ func AstWithOpts(x interface{}, opts Opts) ([]byte, error) {

o := fmtOpts{}

o.regoV1 = opts.RegoV1

// Preprocess the AST. Set any required defaults and calculate
// values required for printing the formatted output.
ast.WalkNodes(x, func(x ast.Node) bool {
Expand Down Expand Up @@ -154,7 +164,10 @@ func AstWithOpts(x interface{}, opts Opts) ([]byte, error) {

switch x := x.(type) {
case *ast.Module:
if moduleIsRegoV1Compatible(x) {
if o.regoV1 {
x.Imports = ensureRegoV1Import(x.Imports)
}
if o.regoV1 || moduleIsRegoV1Compatible(x) {
x.Imports = future.FilterFutureImports(x.Imports)
} else {
for kw := range extraFutureKeywordImports {
Expand Down Expand Up @@ -360,7 +373,7 @@ func (w *writer) writeRule(rule *ast.Rule, isElse bool, o fmtOpts, comments []*a
return comments
}

if o.ifs && partialSetException {
if (o.regoV1 || o.ifs) && partialSetException {
w.write(" if")
if len(rule.Body) == 1 {
if rule.Body[0].Location.Row == rule.Head.Location.Row {
Expand Down Expand Up @@ -503,14 +516,22 @@ func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst bool, o fm
if head.Value != nil &&
(head.Key != nil || ast.Compare(head.Value, ast.BooleanTerm(true)) != 0 || isExpandedConst || isDefault) {

if head.Location == head.Value.Location && head.Name != "else" {
// in rego v1, explicitly print value for ref-head constants that aren't partial set assignments, e.g.:
// * a -> parser error, won't reach here
// * a.b -> a contains "b"
// * a.b.c -> a.b.c := true
// * a.b.c.d -> a.b.c.d := true
isRegoV1RefConst := o.regoV1 && isExpandedConst && head.Key == nil // && len(head.Reference) > 2

if head.Location == head.Value.Location && head.Name != "else" && !isRegoV1RefConst {
// If the value location is the same as the location of the head,
// we know that the value is generated, i.e. f(1)
// Don't print the value (` = true`) as it is implied.
return comments
}

if head.Assign {
if head.Assign || o.regoV1 {
// preserve assignment operator, and enforce it if formatting for Rego v1
w.write(" := ")
} else {
w.write(" = ")
Expand Down Expand Up @@ -1407,6 +1428,24 @@ func ensureFutureKeywordImport(imps []*ast.Import, kw string) []*ast.Import {
return append(imps, imp)
}

func ensureRegoV1Import(imps []*ast.Import) []*ast.Import {
return ensureImport(imps, ast.RegoV1CompatibleRef)
}

func ensureImport(imps []*ast.Import, path ast.Ref) []*ast.Import {
for _, imp := range imps {
p := imp.Path.Value.(ast.Ref)
if p.Equal(path) {
return imps
}
}
imp := &ast.Import{
Path: ast.NewTerm(path),
}
imp.Location = defaultLocation(imp)
return append(imps, imp)
}

// ArgErrDetail but for `fmt` checks since compiler has not run yet.
type ArityFormatErrDetail struct {
Have []string `json:"have"`
Expand Down

0 comments on commit 7370ed1

Please sign in to comment.