Skip to content

Commit

Permalink
lsp: implement used refs completions (#794)
Browse files Browse the repository at this point in the history
This implements a completions provider that makes suggestions based on
refs users have already typed into their files.

Signed-off-by: Charlie Egan <[email protected]>
  • Loading branch information
charlieegan3 committed Jun 5, 2024
1 parent b4eaad3 commit fdda4b5
Show file tree
Hide file tree
Showing 15 changed files with 407 additions and 18 deletions.
27 changes: 27 additions & 0 deletions internal/lsp/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ type Cache struct {
// fileRefs is expected to be updated when a file is successfully parsed.
fileRefs map[string]map[string]types.Ref
fileRefMu sync.Mutex

// usedRefs is a map of file URI to a list of string ref names used in that file.
// These are intended to be used for completions in that file.
usedRefs map[string][]string
usedRefsMu sync.Mutex
}

func NewCache() *Cache {
Expand All @@ -57,6 +62,8 @@ func NewCache() *Cache {
builtinPositionsFile: make(map[string]map[uint][]types.BuiltinPosition),

fileRefs: make(map[string]map[string]types.Ref),

usedRefs: make(map[string][]string),
}
}

Expand Down Expand Up @@ -233,6 +240,22 @@ func (c *Cache) GetAllFileRefs() map[string]map[string]types.Ref {
return maps.Clone(c.fileRefs)
}

func (c *Cache) SetUsedRefs(uri string, items []string) {
c.usedRefsMu.Lock()
defer c.usedRefsMu.Unlock()

c.usedRefs[uri] = items
}

func (c *Cache) GetUsedRefs(uri string) ([]string, bool) {
c.usedRefsMu.Lock()
defer c.usedRefsMu.Unlock()

refs, ok := c.usedRefs[uri]

return refs, ok
}

// Delete removes all cached data for a given URI.
func (c *Cache) Delete(uri string) {
c.fileContentsMu.Lock()
Expand Down Expand Up @@ -262,6 +285,10 @@ func (c *Cache) Delete(uri string) {
c.fileRefMu.Lock()
delete(c.fileRefs, uri)
c.fileRefMu.Unlock()

c.usedRefsMu.Lock()
delete(c.usedRefs, uri)
c.usedRefsMu.Unlock()
}

func UpdateCacheForURIFromDisk(cache *Cache, uri, path string) (string, error) {
Expand Down
1 change: 1 addition & 0 deletions internal/lsp/completions/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func NewDefaultManager(c *cache.Cache) *Manager {
m.RegisterProvider(&providers.RuleHeadKeyword{})
m.RegisterProvider(&providers.Input{})
m.RegisterProvider(&providers.CommonRule{})
m.RegisterProvider(&providers.UsedRefs{})

return m
}
Expand Down
2 changes: 1 addition & 1 deletion internal/lsp/completions/providers/commonrule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ deny := false
}

c.SetModule(testCaseFileURI, mod)
c.SetFileRefs(testCaseFileURI, refs.ForModule(mod))
c.SetFileRefs(testCaseFileURI, refs.DefinedInModule(mod))

p := &CommonRule{}

Expand Down
4 changes: 2 additions & 2 deletions internal/lsp/completions/providers/packagerefs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import
mod := parse.MustParseModule(contents)
c.SetModule(uri, mod)

c.SetFileRefs(uri, refs.ForModule(mod))
c.SetFileRefs(uri, refs.DefinedInModule(mod))
}

p := &PackageRefs{}
Expand Down Expand Up @@ -119,7 +119,7 @@ import
mod := parse.MustParseModule(contents)
c.SetModule(uri, mod)

c.SetFileRefs(uri, refs.ForModule(mod))
c.SetFileRefs(uri, refs.DefinedInModule(mod))
}

p := &PackageRefs{}
Expand Down
2 changes: 1 addition & 1 deletion internal/lsp/completions/providers/rulehead_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ funckyfunc := true

c.SetFileContents(uri, contents)
c.SetModule(uri, mod)
c.SetFileRefs(uri, refs.ForModule(mod))
c.SetFileRefs(uri, refs.DefinedInModule(mod))
}

p := &RuleHead{}
Expand Down
2 changes: 1 addition & 1 deletion internal/lsp/completions/providers/rulerefs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ deny := false

c.SetModule(uri, mod)

c.SetFileRefs(uri, refs.ForModule(mod))
c.SetFileRefs(uri, refs.DefinedInModule(mod))
}

c.SetFileContents("file:///example.rego", currentlyEditingFileContents+"\n\nallow if ")
Expand Down
66 changes: 66 additions & 0 deletions internal/lsp/completions/providers/used_refs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package providers

import (
"strings"

"github.com/styrainc/regal/internal/lsp/cache"
"github.com/styrainc/regal/internal/lsp/types"
"github.com/styrainc/regal/internal/lsp/types/completion"
)

// UsedRefs is a completion provider that provides completions for refs
// that have already been typed into a module.
type UsedRefs struct{}

func (*UsedRefs) Run(c *cache.Cache, params types.CompletionParams, _ *Options) ([]types.CompletionItem, error) {
fileURI := params.TextDocument.URI

lines, currentLine := completionLineHelper(c, fileURI, params.Position.Line)
if len(lines) < 1 || currentLine == "" {
return []types.CompletionItem{}, nil
}

if !inRuleBody(currentLine) {
return []types.CompletionItem{}, nil
}

words := patternWhiteSpace.Split(strings.TrimSpace(currentLine), -1)
lastWord := words[len(words)-1]

refNames, ok := c.GetUsedRefs(fileURI)
if !ok {
return []types.CompletionItem{}, nil
}

items := []types.CompletionItem{}

for _, ref := range refNames {
if !strings.HasPrefix(ref, lastWord) {
continue
}

items = append(items, types.CompletionItem{
Label: ref,
Kind: completion.Reference,
Detail: "Existing ref used in module",
Documentation: &types.MarkupContent{
Kind: "markdown",
Value: `Existing ref used in module`,
},
TextEdit: &types.TextEdit{
Range: types.Range{
Start: types.Position{
Line: params.Position.Line,
Character: params.Position.Character - uint(len(lastWord)),
},
End: types.Position{
Line: params.Position.Line, Character: params.Position.Character,
},
},
NewText: ref,
},
})
}

return items, nil
}
101 changes: 101 additions & 0 deletions internal/lsp/completions/providers/used_refs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package providers

import (
"context"
"slices"
"strings"
"testing"

"github.com/styrainc/regal/internal/lsp/cache"
"github.com/styrainc/regal/internal/lsp/completions/refs"
"github.com/styrainc/regal/internal/lsp/types"
"github.com/styrainc/regal/internal/parse"
)

func TestUsedRefs(t *testing.T) {
t.Parallel()

c := cache.NewCache()

currentlyEditingFileContents := `package example
import rego.v1
import data.foo as wow
import data.bar
allow if input.user == "admin"
allow if data.users.admin == input.user
deny contains wow.password if {
input.magic == true
}
deny contains input.parrot if {
bar.parrot != "a bird"
}
`

uri := "file:///example.rego"

mod, err := parse.Module(uri, currentlyEditingFileContents)
if err != nil {
t.Fatalf("Unexpected error when parsing %s contents: %v", uri, err)
}

c.SetModule(uri, mod)
c.SetFileRefs(uri, refs.DefinedInModule(mod))

refNames, err := refs.UsedInModule(context.Background(), mod)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}

c.SetUsedRefs(uri, refNames)

c.SetFileContents(uri, currentlyEditingFileContents+`
allow if {
true == i
}`)

p := &UsedRefs{}

completionParams := types.CompletionParams{
TextDocument: types.TextDocumentIdentifier{
URI: uri,
},
Position: types.Position{
Line: 20,
Character: 11,
},
}

completions, err := p.Run(c, completionParams, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

expectedRefs := []string{
"input.magic",
"input.parrot",
"input.user",
}
slices.Sort(expectedRefs)

foundRefs := make([]string, len(completions))

for i, c := range completions {
foundRefs[i] = c.Label
}

slices.Sort(foundRefs)

if !slices.Equal(expectedRefs, foundRefs) {
t.Fatalf(
"Expected completions to be\n%s\ngot:\n%s",
strings.Join(expectedRefs, "\n"),
strings.Join(foundRefs, "\n"),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (
"github.com/styrainc/regal/internal/lsp/types"
)

// ForModule returns a map of refs and details about them to be used in completions that
// DefinedInModule returns a map of refs and details about them to be used in completions that
// were found in the given module.
func ForModule(module *ast.Module) map[string]types.Ref {
func DefinedInModule(module *ast.Module) map[string]types.Ref {
modKey := module.Package.Path.String()

// first, create a reference for the package using the metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestForModule_Package(t *testing.T) {
package example
`)

items := ForModule(mod)
items := DefinedInModule(mod)

expectedRefs := map[string]types.Ref{
"data.example": {
Expand Down Expand Up @@ -123,7 +123,7 @@ deny contains "strings" if true
pi := 3.14
`)

items := ForModule(mod)
items := DefinedInModule(mod)

expectedRefs := map[string]types.Ref{
"data.example": {
Expand Down
33 changes: 33 additions & 0 deletions internal/lsp/completions/refs/rego/ref_names.rego
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package lsp.completions

import rego.v1

import data.regal.ast

# ref_names returns a list of ref names that are used in the module.
# built-in functions are not included as they are provided by another completions provider.
# imports are not included as we need to use the imported_identifier instead
# (i.e. maybe an alias).
ref_names contains name if {
some ref in ast.all_refs

name := ast.ref_to_string(ref.value)

not name in ast.builtin_functions_called
not name in imports
}

# if a user has imported data.foo, then foo should be suggested.
# if they have imported data.foo as bar, then bar should be suggested.
# this also has the benefit of skipping future.* and rego.v1 as
# imported_identifiers will only match data.* and input.*
ref_names contains name if {
some name in ast.imported_identifiers
}

# imports are not shown as we need to show the imported alias instead
imports contains ref if {
some imp in ast.imports

ref := ast.ref_to_string(imp.path.value)
}
Loading

0 comments on commit fdda4b5

Please sign in to comment.