From acbdb880d0cd1ea3e8963c37b341589f9691c1c4 Mon Sep 17 00:00:00 2001
From: Charlie Egan
Date: Thu, 15 Aug 2024 11:30:51 +0100
Subject: [PATCH] lsp: Implement bundle use in workspace Eval (#987)

* WIP

* lsp/bundles: add a bundle cache

This cache can be used to maintain a cache of bundles in memory that are
found in the workspace.

Signed-off-by: Charlie Egan

* lsp: Implement bundle use in workspace Eval

Signed-off-by: Charlie Egan

* PR Feedback

---------

Signed-off-by: Charlie Egan
---
 .golangci.yaml                       |   1 +
 internal/lsp/bundles/bundles.go      |  38 ++++
 internal/lsp/bundles/bundles_test.go | 101 ++++++++++
 internal/lsp/bundles/cache.go        | 263 +++++++++++++++++++++++++++
 internal/lsp/bundles/cache_test.go   | 165 +++++++++++++++++
 internal/lsp/eval.go                 |  43 ++++-
 internal/lsp/server.go               |  22 ++-
 7 files changed, 622 insertions(+), 11 deletions(-)
 create mode 100644 internal/lsp/bundles/bundles.go
 create mode 100644 internal/lsp/bundles/bundles_test.go
 create mode 100644 internal/lsp/bundles/cache.go
 create mode 100644 internal/lsp/bundles/cache_test.go

diff --git a/.golangci.yaml b/.golangci.yaml
index a03b2292..0addac46 100644
--- a/.golangci.yaml
+++ b/.golangci.yaml
@@ -25,6 +25,7 @@ linters:
     - nolintlint
     - depguard
     - gomoddirectives # need replacements for wasip1
+    - execinquery # deprecated
 linters-settings:
   tagliatelle:
     case:
diff --git a/internal/lsp/bundles/bundles.go b/internal/lsp/bundles/bundles.go
new file mode 100644
index 00000000..47e26a55
--- /dev/null
+++ b/internal/lsp/bundles/bundles.go
@@ -0,0 +1,38 @@
+package bundles
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"slices"
+
+	"github.com/open-policy-agent/opa/bundle"
+)
+
+// LoadDataBundle loads a bundle from the given path but only includes data
+// files. The path must contain a bundle manifest file.
+func LoadDataBundle(path string) (bundle.Bundle, error) {
+	if _, err := os.Stat(filepath.Join(path, ".manifest")); err != nil {
+		return bundle.Bundle{}, fmt.Errorf("manifest file was not found at bundle path %q", path)
+	}
+
+	b, err := bundle.NewCustomReader(bundle.NewDirectoryLoader(path).WithFilter(dataFileLoaderFilter)).Read()
+	if err != nil {
+		return bundle.Bundle{}, fmt.Errorf("failed to read bundle: %w", err)
+	}
+
+	return b, nil
+}
+
+// dataFileLoaderFilter is the loader filter used by LoadDataBundle. It
+// reports true for files that must be skipped, which is every file other
+// than the bundle manifest and data.json, data.yml or data.yaml documents.
+func dataFileLoaderFilter(abspath string, info os.FileInfo, _ int) bool {
+	if info.IsDir() {
+		return false
+	}
+
+	basename := filepath.Base(abspath)
+
+	return !slices.Contains([]string{".manifest", "data.json", "data.yml", "data.yaml"}, basename)
+}
diff --git a/internal/lsp/bundles/bundles_test.go b/internal/lsp/bundles/bundles_test.go
new file mode 100644
index 00000000..f63c9846
--- /dev/null
+++ b/internal/lsp/bundles/bundles_test.go
@@ -0,0 +1,101 @@
+package bundles
+
+import (
+	"os"
+	"path/filepath"
+	"reflect"
+	"testing"
+)
+
+func TestLoadDataBundle(t *testing.T) {
+	t.Parallel()
+
+	testCases := map[string]struct {
+		path         string
+		files        map[string]string
+		expectedData any
+	}{
+		"simple bundle": {
+			path: "foo",
+			files: map[string]string{
+				"foo/.manifest": `{"roots":["foo"]}`,
+				"foo/data.json": `{"foo": "bar"}`,
+			},
+			expectedData: map[string]any{
+				"foo": "bar",
+			},
+		},
+		"nested bundle": {
+			path: "foo",
+			files: map[string]string{
+				"foo/.manifest":     `{"roots":["foo", "bar"]}`,
+				"foo/data.yml":      `foo: bar`,
+				"foo/bar/data.yaml": `bar: baz`,
+			},
+			expectedData: map[string]any{
+				"foo": "bar",
+				"bar": map[string]any{
+					"bar": "baz",
+				},
+			},
+		},
+		"array data": {
+			path: "foo",
+			files: map[string]string{
+				"foo/.manifest":     `{"roots":["bar"]}`,
+				"foo/bar/data.json": `[{"foo": "bar"}]`,
+			},
+			expectedData: map[string]any{
+				"bar": []any{
+					map[string]any{
+						"foo": "bar",
+					},
+				},
+			},
+		},
+		"rego files": {
+			path: "foo",
+			files: map[string]string{
+				"foo/.manifest": `{"roots":["foo"]}`,
+				"foo/rego.rego": `package foo`,
+			},
+			expectedData: map[string]any{},
+		},
+	}
+
+	for testCase, testData := range testCases {
+		t.Run(testCase, func(t *testing.T) {
+			t.Parallel()
+
+			workspacePath := t.TempDir()
+
+			// create the workspace state
+			for file, contents := range testData.files {
+				filePath := filepath.Join(workspacePath, file)
+
+				dir := filepath.Dir(filePath)
+				if err := os.MkdirAll(dir, 0o755); err != nil {
+					t.Fatalf("failed to create directory %s: %v", dir, err)
+				}
+
+				err := os.WriteFile(filePath, []byte(contents), 0o600)
+				if err != nil {
+					t.Fatalf("failed to write file %s: %v", filePath, err)
+				}
+			}
+
+			b, err := LoadDataBundle(filepath.Join(workspacePath, testData.path))
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if !reflect.DeepEqual(b.Data, testData.expectedData) {
+				t.Fatalf("expected data to be %v, but got %v", testData.expectedData, b.Data)
+			}
+
+			if len(b.Modules) != 0 {
+				t.Fatalf("expected no modules, but got %d", len(b.Modules))
+			}
+		})
+	}
+}
diff --git a/internal/lsp/bundles/cache.go b/internal/lsp/bundles/cache.go
new file mode 100644
index 00000000..da0edf01
--- /dev/null
+++ b/internal/lsp/bundles/cache.go
@@ -0,0 +1,263 @@
+package bundles
+
+import (
+	"bytes"
+	//nolint:gosec
+	"crypto/md5"
+	"errors"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/open-policy-agent/opa/bundle"
+
+	"github.com/styrainc/regal/internal/util"
+)
+
+// Cache is a struct that maintains a number of bundles in memory and
+// provides a way to refresh them when the source files change.
+type Cache struct {
+	workspacePath string
+	bundles       map[string]*cacheBundle
+	errorLog      io.Writer
+}
+
+// CacheOptions holds the settings used by NewCache. WorkspacePath is the
+// directory searched for bundles; ErrorLog is optional and, when set,
+// receives a message for each bundle that fails to refresh.
+type CacheOptions struct {
+	WorkspacePath string
+	ErrorLog      io.Writer
+}
+
+// NewCache creates an empty Cache for the workspace described by opts. No
+// bundles are loaded until Refresh is called.
+func NewCache(opts *CacheOptions) *Cache {
+	workspacePath := opts.WorkspacePath
+
+	// An unset path is left empty so that Refresh rejects it, rather than
+	// being turned into the filesystem root by the separator added below.
+	if workspacePath != "" && !strings.HasSuffix(workspacePath, string(filepath.Separator)) {
+		workspacePath += string(filepath.Separator)
+	}
+
+	c := &Cache{
+		workspacePath: workspacePath,
+		bundles:       make(map[string]*cacheBundle),
+	}
+
+	if opts.ErrorLog != nil {
+		c.errorLog = opts.ErrorLog
+	}
+
+	return c
+}
+
+// Refresh walks the workspace path and loads or refreshes any bundles that
+// have changed since the last refresh. It returns the roots of the bundles
+// that were (re)loaded. A bundle that fails to refresh is reported to the
+// error log, when one is set, and does not fail the whole refresh.
+func (c *Cache) Refresh() ([]string, error) {
+	if c.workspacePath == "" {
+		return nil, errors.New("workspace path is empty")
+	}
+
+	// find all the bundle roots that are currently present on disk
+	var foundBundleRoots []string
+
+	err := filepath.Walk(c.workspacePath, func(path string, info os.FileInfo, err error) error {
+		// info is nil when path could not be inspected; err then holds the
+		// reason, so return it instead of dereferencing a nil info
+		if info == nil {
+			return err
+		}
+
+		if info.IsDir() && (info.Name() == ".git" || info.Name() == ".idea") {
+			return filepath.SkipDir
+		}
+
+		if info.IsDir() {
+			return nil
+		}
+
+		if filepath.Base(path) != ".manifest" {
+			return nil
+		}
+
+		// roots are stored relative to the workspace; a manifest placed
+		// directly in the workspace directory gets the root "."
+		root, err := filepath.Rel(c.workspacePath, filepath.Dir(path))
+		if err != nil {
+			return fmt.Errorf("failed to determine bundle root for %q: %w", path, err)
+		}
+
+		foundBundleRoots = append(foundBundleRoots, root)
+
+		return nil
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to walk workspace path: %w", err)
+	}
+
+	var refreshedBundles []string
+
+	// refresh any bundles that have changed
+	for _, root := range foundBundleRoots {
+		if _, ok := c.bundles[root]; !ok {
+			c.bundles[root] = &cacheBundle{}
+		}
+
+		refreshed, err := c.bundles[root].Refresh(filepath.Join(c.workspacePath, root))
+		if err != nil {
+			if c.errorLog != nil {
+				fmt.Fprintf(c.errorLog, "failed to refresh bundle %q: %v\n", root, err)
+			}
+
+			continue
+		}
+
+		if refreshed {
+			refreshedBundles = append(refreshedBundles, root)
+		}
+	}
+
+	// remove any bundles that are no longer present on disk
+	for root := range c.bundles {
+		found := false
+
+		for _, foundRoot := range foundBundleRoots {
+			if root == foundRoot {
+				found = true
+
+				break
+			}
+		}
+
+		if !found {
+			delete(c.bundles, root)
+		}
+	}
+
+	return refreshedBundles, nil
+}
+
+// List returns a list of all the bundle roots that are currently present in
+// the cache.
+func (c *Cache) List() []string {
+	return util.Keys(c.bundles)
+}
+
+// Get returns the bundle for the given root from the cache.
+func (c *Cache) Get(root string) (bundle.Bundle, bool) {
+	b, ok := c.bundles[root]
+	if !ok {
+		return bundle.Bundle{}, false
+	}
+
+	return b.bundle, true
+}
+
+// All returns all the bundles in the cache.
+func (c *Cache) All() map[string]bundle.Bundle {
+	bundles := make(map[string]bundle.Bundle)
+
+	for root, cacheBundle := range c.bundles {
+		bundles[root] = cacheBundle.bundle
+	}
+
+	return bundles
+}
+
+// cacheBundle is an internal struct that holds a bundle.Bundle and the MD5
+// hash of each source file in the bundle. Hashes are used to determine if
+// the bundle should be reloaded.
+type cacheBundle struct {
+	bundle        bundle.Bundle
+	sourceDigests map[string][]byte
+}
+
+// Refresh loads the bundle from disk and updates the cache if any of the
+// source files have changed since the last refresh. It reports whether the
+// bundle was reloaded. If reloading fails, the previously loaded bundle and
+// its digests are kept.
+func (c *cacheBundle) Refresh(path string) (bool, error) {
+	onDiskSourceDigests := make(map[string][]byte)
+
+	// walk the bundle path and calculate the MD5 hash of each file on disk
+	// at the moment
+	err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
+		if err != nil {
+			return err
+		}
+
+		if info.IsDir() || dataFileLoaderFilter(path, info, 0) {
+			return nil
+		}
+
+		hash, err := calculateMD5(path)
+		if err != nil {
+			return err
+		}
+
+		onDiskSourceDigests[path] = hash
+
+		return nil
+	})
+	if err != nil {
+		return false, fmt.Errorf("failed to walk bundle path %q: %w", path, err)
+	}
+
+	// compare the files on disk with the files that have been seen before
+	// and return without reloading the bundle if there have been no changes
+	if len(onDiskSourceDigests) == len(c.sourceDigests) {
+		changed := false
+
+		for path, hash := range onDiskSourceDigests {
+			if !bytes.Equal(hash, c.sourceDigests[path]) {
+				changed = true
+
+				break
+			}
+		}
+
+		if !changed {
+			return false, nil
+		}
+	}
+
+	// if there has been any change in any of the source files, then
+	// reload the bundle; it is only stored once loading has succeeded so
+	// that a failed reload does not discard the bundle already cached
+	loaded, err := LoadDataBundle(path)
+	if err != nil {
+		return false, fmt.Errorf("failed to load bundle %q: %w", path, err)
+	}
+
+	c.bundle = loaded
+
+	// update the bundle's sourceDigests to the new on-disk state after a
+	// successful refresh
+	c.sourceDigests = onDiskSourceDigests
+
+	return true, nil
+}
+
+// calculateMD5 returns the MD5 digest of the contents of the file at
+// filePath. The digest is only used to detect changes between refreshes.
+func calculateMD5(filePath string) ([]byte, error) {
+	file, err := os.Open(filePath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to open file %q: %w", filePath, err)
+	}
+	defer file.Close()
+
+	//nolint:gosec
+	hash := md5.New()
+	if _, err := io.Copy(hash, file); err != nil {
+		return nil, fmt.Errorf("failed to calculate MD5 hash for file %q: %w", filePath, err)
+	}
+
+	return hash.Sum(nil), nil
+}
diff --git a/internal/lsp/bundles/cache_test.go b/internal/lsp/bundles/cache_test.go
new file mode 100644
index 00000000..2068a7d4
--- /dev/null
+++ b/internal/lsp/bundles/cache_test.go
@@ -0,0 +1,165 @@
+package bundles
+
+import (
+	"os"
+	"path/filepath"
+	"reflect"
+	"slices"
+	"testing"
+)
+
+func TestRefresh(t *testing.T) {
+	t.Parallel()
+
+	workspacePath := t.TempDir()
+
+	// create the initial filesystem state
+	files := map[string]string{
+		"foo/.manifest": `{"roots":["foo"]}`,
+		"foo/data.json": `{"foo": "bar"}`,
+	}
+
+	writeFiles := func(files map[string]string) {
+		for file, contents := range files {
+			filePath := filepath.Join(workspacePath, file)
+
+			dir := filepath.Dir(filePath)
+			if err := os.MkdirAll(dir, 0o755); err != nil {
+				t.Fatalf("failed to create directory %s: %v", dir, err)
+			}
+
+			err := os.WriteFile(filePath, []byte(contents), 0o600)
+			if err != nil {
+				t.Fatalf("failed to write file %s: %v", filePath, err)
+			}
+		}
+	}
+
+	writeFiles(files)
+
+	c := NewCache(&CacheOptions{WorkspacePath: workspacePath})
+
+	// perform the first load of the bundles
+	refreshedBundles, err := c.Refresh()
+	if err != nil {
+		t.Fatalf("failed to refresh cache: %v", err)
+	}
+
+	if !slices.Equal(refreshedBundles, []string{"foo"}) {
+		t.Fatalf("unexpected refreshed bundles: %v", refreshedBundles)
+	}
+
+	if len(c.List()) != 1 {
+		t.Fatalf("unexpected number of bundles: %d", len(c.List()))
+	}
+
+	fooBundle, ok := c.Get("foo")
+	if !ok {
+		t.Fatalf("failed to get bundle foo")
+	}
+
+	if !reflect.DeepEqual(fooBundle.Data, map[string]any{"foo": "bar"}) {
+		t.Fatalf("unexpected bundle data: %v", fooBundle.Data)
+	}
+
+	if fooBundle.Manifest.Roots == nil {
+		t.Fatalf("unexpected bundle roots: %v", fooBundle.Manifest.Roots)
+	}
+
+	if !reflect.DeepEqual(*fooBundle.Manifest.Roots, []string{"foo"}) {
+		t.Fatalf("unexpected bundle roots: %v", *fooBundle.Manifest.Roots)
+	}
+
+	// perform the second load of the bundles, after no changes on disk
+	refreshedBundles, err = c.Refresh()
+	if err != nil {
+		t.Fatalf("failed to refresh cache: %v", err)
+	}
+
+	if !slices.Equal(refreshedBundles, []string{}) {
+		t.Fatalf("unexpected refreshed bundles: %v", refreshedBundles)
+	}
+
+	// add a new unrelated file
+	writeFiles(
+		map[string]string{
+			"foo/foo.rego": `package wow`,
+		},
+	)
+
+	// perform the third load of the bundles, after adding a new unrelated file
+	refreshedBundles, err = c.Refresh()
+	if err != nil {
+		t.Fatalf("failed to refresh cache: %v", err)
+	}
+
+	if !slices.Equal(refreshedBundles, []string{}) {
+		t.Fatalf("unexpected refreshed bundles: %v", refreshedBundles)
+	}
+
+	// update the data in the bundle
+	writeFiles(
+		map[string]string{
+			"foo/data.json": `{"foo": "baz"}`,
+		},
+	)
+
+	refreshedBundles, err = c.Refresh()
+	if err != nil {
+		t.Fatalf("failed to refresh cache: %v", err)
+	}
+
+	if !slices.Equal(refreshedBundles, []string{"foo"}) {
+		t.Fatalf("unexpected refreshed bundles: %v", refreshedBundles)
+	}
+
+	fooBundle, ok = c.Get("foo")
+	if !ok {
+		t.Fatalf("failed to get bundle foo")
+	}
+
+	if !reflect.DeepEqual(fooBundle.Data, map[string]any{"foo": "baz"}) {
+		t.Fatalf("unexpected bundle data: %v", fooBundle.Data)
+	}
+
+	// create a new bundle
+	writeFiles(
+		map[string]string{
+			"bar/.manifest": `{"roots":["bar"]}`,
+			"bar/data.json": `{"bar": true}`,
+		},
+	)
+
+	refreshedBundles, err = c.Refresh()
+	if err != nil {
+		t.Fatalf("failed to refresh cache: %v", err)
+	}
+
+	if !slices.Equal(refreshedBundles, []string{"bar"}) {
+		t.Fatalf("unexpected refreshed bundles: %v", refreshedBundles)
+	}
+
+	barBundle, ok := c.Get("bar")
+	if !ok {
+		t.Fatalf("failed to get bundle bar")
+	}
+
+	if !reflect.DeepEqual(barBundle.Data, map[string]any{"bar": true}) {
+		t.Fatalf("unexpected bundle data: %v", barBundle.Data)
+	}
+
+	// remove the foo bundle
+	err = os.RemoveAll(filepath.Join(workspacePath, "foo"))
+	if err != nil {
+		t.Fatalf("failed to remove foo bundle: %v", err)
+	}
+
+	_, err = c.Refresh()
+	if err != nil {
+		t.Fatalf("failed to refresh cache: %v", err)
+	}
+
+	if !slices.Equal(c.List(), []string{"bar"}) {
+		t.Fatalf("unexpected bundle list: %v", c.List())
+	}
+}
diff --git a/internal/lsp/eval.go b/internal/lsp/eval.go
index 44c1c995..b28205dd 100644
--- a/internal/lsp/eval.go
+++ b/internal/lsp/eval.go
@@ -25,6 +25,7 @@ func (l *LanguageServer) Eval(
 	query string,
 	input io.Reader,
 	printHook print.Hook,
+	dataBundles map[string]bundle.Bundle,
 ) (rego.ResultSet, error) {
 	modules := l.cache.GetAllModules()
 	moduleFiles := make([]bundle.ModuleFile, 0, len(modules))
@@ -37,16 +38,31 @@ func (l *LanguageServer) Eval(
 		})
 	}
 
-	bd := bundle.Bundle{
-		Data: make(map[string]any),
+	allBundles := make(map[string]bundle.Bundle)
+
+	for k, v := range dataBundles {
+		if v.Manifest.Roots == nil {
+			l.logError(fmt.Errorf("bundle %s has no roots and will be skipped", k))
+
+			continue
+		}
+
+		allBundles[k] = v
+	}
+
+	allBundles["workspace"] = bundle.Bundle{
 		Manifest: bundle.Manifest{
-			Roots:    &[]string{""},
+			// there is no data in this bundle so the roots are not used,
+			// however, roots must be set.
+			Roots:    &[]string{"workspace"},
 			Metadata: map[string]any{"name": "workspace"},
 		},
 		Modules: moduleFiles,
+		// Data is all sourced from the dataBundles instead
+		Data: make(map[string]any),
 	}
 
-	regoArgs := prepareRegoArgs(ast.MustParseBody(query), bd, printHook)
+	regoArgs := prepareRegoArgs(ast.MustParseBody(query), allBundles, printHook)
 
 	pq, err := rego.New(regoArgs...).PrepareForEval(ctx)
 	if err != nil {
@@ -103,7 +119,12 @@ func (l *LanguageServer) EvalWorkspacePath(
 
 	hook := PrintHook{Output: make(map[int][]string)}
 
-	result, err := l.Eval(ctx, resultQuery, input, hook)
+	var bs map[string]bundle.Bundle
+	if l.bundleCache != nil {
+		bs = l.bundleCache.All()
+	}
+
+	result, err := l.Eval(ctx, resultQuery, input, hook, bs)
 	if err != nil {
 		return EvalPathResult{}, fmt.Errorf("failed evaluating query: %w", err)
 	}
@@ -120,15 +141,21 @@ func (l *LanguageServer) EvalWorkspacePath(
 	return EvalPathResult{Value: res, PrintOutput: hook.Output}, nil
 }
 
-func prepareRegoArgs(query ast.Body, bd bundle.Bundle, printHook print.Hook) []func(*rego.Rego) {
-	return []func(*rego.Rego){
+func prepareRegoArgs(query ast.Body, bundles map[string]bundle.Bundle, printHook print.Hook) []func(*rego.Rego) {
+	bundleArgs := make([]func(*rego.Rego), 0, len(bundles))
+	for key, b := range bundles {
+		bundleArgs = append(bundleArgs, rego.ParsedBundle(key, &b))
+	}
+
+	baseArgs := []func(*rego.Rego){
 		rego.ParsedQuery(query),
-		rego.ParsedBundle("workspace", &bd),
 		rego.Function2(builtins.RegalParseModuleMeta, builtins.RegalParseModule),
 		rego.Function1(builtins.RegalLastMeta, builtins.RegalLast),
 		rego.EnablePrintStatements(true),
 		rego.PrintHook(printHook),
 	}
+
+	return append(baseArgs, bundleArgs...)
 }
 
 type PrintHook struct {
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index 9edd939e..0a5f6d5f 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -23,6 +23,7 @@ import (
 
 	"github.com/styrainc/regal/bundle"
 	rio "github.com/styrainc/regal/internal/io"
+	"github.com/styrainc/regal/internal/lsp/bundles"
 	"github.com/styrainc/regal/internal/lsp/cache"
 	"github.com/styrainc/regal/internal/lsp/clients"
 	"github.com/styrainc/regal/internal/lsp/commands"
@@ -91,8 +92,9 @@ type LanguageServer struct {
 
 	clientInitializationOptions types.InitializationOptions
 
-	cache     *cache.Cache
-	regoStore storage.Store
+	cache       *cache.Cache
+	regoStore   storage.Store
+	bundleCache *bundles.Cache
 
 	completionsManager *completions.Manager
 
@@ -1815,7 +1817,14 @@ func (l *LanguageServer) handleInitialize(
 	}
 
 	if l.workspaceRootURI != "" {
-		configFile, err := config.FindConfig(uri.ToPath(l.clientIdentifier, l.workspaceRootURI))
+		workspaceRootPath := uri.ToPath(l.clientIdentifier, l.workspaceRootURI)
+
+		l.bundleCache = bundles.NewCache(&bundles.CacheOptions{
+			WorkspacePath: workspaceRootPath,
+			ErrorLog:      l.errorLog,
+		})
+
+		configFile, err := config.FindConfig(workspaceRootPath)
 		if err == nil {
 			l.configWatcher.Watch(configFile.Name())
 		}
@@ -1887,6 +1896,13 @@ func (l *LanguageServer) loadWorkspaceContents(ctx context.Context, newOnly bool
 		return nil, fmt.Errorf("failed to walk workspace dir %q: %w", workspaceRootPath, err)
 	}
 
+	if l.bundleCache != nil {
+		_, err := l.bundleCache.Refresh()
+		if err != nil {
+			return nil, fmt.Errorf("failed to refresh the bundle cache: %w", err)
+		}
+	}
+
 	return changedOrNewURIs, nil
 }
 