Skip to content

Commit

Permalink
lsp: completions for rule heads (#770)
Browse files Browse the repository at this point in the history
Fixes: #765

Signed-off-by: Charlie Egan <[email protected]>
  • Loading branch information
charlieegan3 committed May 29, 2024
1 parent 48aa1db commit 51bd977
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 0 deletions.
1 change: 1 addition & 0 deletions internal/lsp/completions/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func NewDefaultManager(c *cache.Cache) *Manager {
m.RegisterProvider(&providers.RegoV1{})
m.RegisterProvider(&providers.PackageRefs{})
m.RegisterProvider(&providers.RuleRefs{})
m.RegisterProvider(&providers.RuleHead{})

return m
}
Expand Down
113 changes: 113 additions & 0 deletions internal/lsp/completions/providers/rulehead.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package providers

import (
"strings"

"github.com/styrainc/regal/internal/lsp/cache"
"github.com/styrainc/regal/internal/lsp/types"
"github.com/styrainc/regal/internal/lsp/types/completion"
)

// RuleHead is a completion provider that returns completions for
// rules found in the same package at the start of a line, so
// when adding new heads, the user can easily add new ones.
type RuleHead struct{}

func (*RuleHead) Run(
c *cache.Cache,
params types.CompletionParams,
_ *Options,
) ([]types.CompletionItem, error) {
fileURI := params.TextDocument.URI

lines, currentLine := completionLineHelper(c, fileURI, params.Position.Line)
if len(lines) < 1 {
return []types.CompletionItem{}, nil
}

if currentLine != "" {
words := patternWhiteSpace.Split(currentLine, -1)

// this provider only suggests rules at the start of a line while
// typing the first word
if len(words) != 1 {
return []types.CompletionItem{}, nil
}

// if the first word is not at the start of the line, then we
// assume the cursor is in a rule body and we exit
if strings.Index(currentLine, words[0]) != 0 {
return []types.CompletionItem{}, nil
}
}

// some version of a parsed mod is needed here to filter refs to suggest
// based on import statements
mod, ok := c.GetModule(fileURI)
if !ok {
return nil, nil
}

modPrefix := mod.Package.Path.String() + "."

refsFromPackage := make(map[string]types.Ref)

// we gather refs from other files in case the package has been defined
// in more than one file
for _, refs := range c.GetAllFileRefs() {
for key, ref := range refs {
// this provider only suggests rules
if ref.Kind != types.Rule && ref.Kind != types.ConstantRule && ref.Kind != types.Function {
continue
}

// only rules from the current package are suggested
if !strings.HasPrefix(key, modPrefix) {
continue
}

refsFromPackage[strings.TrimPrefix(key, modPrefix)] = ref
}
}

items := make([]types.CompletionItem, 0)

for key, ref := range refsFromPackage {
symbol := completion.Variable
detail := "Rule"

switch {
case ref.Kind == types.ConstantRule:
symbol = completion.Constant
detail = "Constant Rule"
case ref.Kind == types.Function:
symbol = completion.Function
detail = "Function"
}

items = append(items, types.CompletionItem{
Label: key,
Kind: symbol,
Detail: detail,
Documentation: &types.MarkupContent{
Kind: "markdown",
Value: ref.Description,
},
TextEdit: &types.TextEdit{
Range: types.Range{
Start: types.Position{
Line: params.Position.Line,
Character: 0,
},
End: types.Position{
Line: params.Position.Line,
Character: uint(len(currentLine)),
},
},
NewText: key,
},
})
}

return items, nil
}
81 changes: 81 additions & 0 deletions internal/lsp/completions/providers/rulehead_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package providers

import (
"slices"
"testing"

"github.com/styrainc/regal/internal/lsp/cache"
"github.com/styrainc/regal/internal/lsp/completions/refs"
"github.com/styrainc/regal/internal/lsp/types"
"github.com/styrainc/regal/internal/parse"
)

func TestRuleHead(t *testing.T) {
t.Parallel()

c := cache.NewCache()

regoFiles := map[string]string{
"file:///foo/foo.rego": `package foo
import rego.v1
default allow := false
allow if count(deny) == 0
deny contains message if {
true
}
_internal := true
funckyfunc := true
`,
}

for uri, contents := range regoFiles {
mod, err := parse.Module(uri, contents)
if err != nil {
t.Fatalf("Unexpected error when parsing %s contents: %v", uri, err)
}

c.SetFileContents(uri, contents)
c.SetModule(uri, mod)
c.SetFileRefs(uri, refs.ForModule(mod))
}

p := &RuleHead{}

completionParams := types.CompletionParams{
TextDocument: types.TextDocumentIdentifier{
URI: "file:///foo/foo.rego",
},
Position: types.Position{
Line: 16,
Character: 0,
},
}

completions, err := p.Run(c, completionParams, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

expectedRefs := []string{"allow", "deny", "_internal", "funckyfunc"}
slices.Sort(expectedRefs)

foundRefs := make([]string, len(completions))

for i, c := range completions {
foundRefs[i] = c.Label
}

slices.Sort(foundRefs)

if !slices.Equal(expectedRefs, foundRefs) {
t.Fatalf("Expected completions to be %v, got: %v", expectedRefs, foundRefs)
}
}

0 comments on commit 51bd977

Please sign in to comment.