Skip to content

Commit

Permalink
Recurse the current directory to find imports to guess (#202)
Browse files Browse the repository at this point in the history
* Recurse the current directory to find imports to guess

* ignoredPaths doesn't make sense with fs.WalkDir

* Establish a basic depth limit on FS traversal

* Avoid unnecessary traversals by just using path.Match

* Prepend root to dir

* Add tests for tree-sitter directory traversal
  • Loading branch information
blast-hardcheese authored Jan 2, 2024
1 parent f481a68 commit 46f4bc0
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 26 deletions.
6 changes: 3 additions & 3 deletions internal/backends/nodejs/grab.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func findImports(ctx context.Context, dir string) (map[string]bool, error) {
foundImportPaths := map[string]bool{}

js := javascript.GetLanguage()
jsPkgs, err := util.GuessWithTreeSitter(ctx, dir, js, importsQuery, jsPathGlobs, []string{})
jsPkgs, err := util.GuessWithTreeSitter(ctx, dir, js, importsQuery, jsPathGlobs, nodeIgnorePathSegments)
if err != nil {
return nil, err
}
Expand All @@ -100,7 +100,7 @@ func findImports(ctx context.Context, dir string) (map[string]bool, error) {
}

ts := typescript.GetLanguage()
tsPkgs, err := util.GuessWithTreeSitter(ctx, dir, ts, importsQuery, tsPathGlobs, []string{})
tsPkgs, err := util.GuessWithTreeSitter(ctx, dir, ts, importsQuery, tsPathGlobs, nodeIgnorePathSegments)
if err != nil {
return nil, err
}
Expand All @@ -110,7 +110,7 @@ func findImports(ctx context.Context, dir string) (map[string]bool, error) {
}

tsx := tsx.GetLanguage()
tsxPkgs, err := util.GuessWithTreeSitter(ctx, dir, tsx, importsQuery, tsxPathGlobs, []string{})
tsxPkgs, err := util.GuessWithTreeSitter(ctx, dir, tsx, importsQuery, tsxPathGlobs, nodeIgnorePathSegments)
if err != nil {
return nil, err
}
Expand Down
4 changes: 4 additions & 0 deletions internal/backends/nodejs/nodejs.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ var nodejsGuessRegexps = util.Regexps([]string{
`(?m)(?:require|import)\s*\(\s*['"]([^'"{}]+)['"]\s*\)`,
})

// nodeIgnorePathSegments is handed to util.GuessWithTreeSitter by findImports
// for the JavaScript, TypeScript and TSX scans as the set of path segment
// names to ignore while looking for source files.
var nodeIgnorePathSegments = map[string]bool{
"node_modules": true,
}

var jsPathGlobs = []string{
"*.js",
"*.jsx",
Expand Down
13 changes: 7 additions & 6 deletions internal/backends/python/grab.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ var importsQuery = `
(comment)? @pragma)
`

var pyPathGlobs = []string{"*.py"}
var pyPathSegmentPatterns = []string{"*.py"}

var pyIgnoreGlobs = []string{
"**/__pycache__/**",
"**/venv/**",
"**/.pythonlibs/**",
var pyIgnorePathSegments = map[string]bool{
"__pycache__": true,
"venv": true,
".pythonlibs": true,
".git": true,
}

var internalModules = map[string]bool{
Expand Down Expand Up @@ -263,7 +264,7 @@ func findImports(ctx context.Context, dir string) (map[string]bool, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "python.grab.findImports")
defer span.Finish()
py := python.GetLanguage()
pkgs, err := util.GuessWithTreeSitter(ctx, dir, py, importsQuery, pyPathGlobs, pyIgnoreGlobs)
pkgs, err := util.GuessWithTreeSitter(ctx, dir, py, importsQuery, pyPathSegmentPatterns, pyIgnorePathSegments)

if err != nil {
return nil, err
Expand Down
48 changes: 31 additions & 17 deletions internal/util/tree-sitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,42 +23,56 @@ type importPragma struct {
Package string
}

const (
// MaximumVisits is the largest number of filesystem nodes, including
// directories, that GuessWithTreeSitter will visit in one walk. Once the
// count is exceeded the walk is cut short, to avoid locking up UPM on
// pathological project configurations.
MaximumVisits = 5000
)

// GuessWithTreeSitter guesses the imports of a directory using tree-sitter.
// For every file under root whose name matches a pattern in pathSegmentPatterns,
// and that is not beneath a path segment marked in ignorePathSegments, it will
// parse the file using lang and queryImports.
// When there's a capture tagged as `@import`, it reports the capture as an import.
// If there's a capture tagged as `@pragma` that's on the same line as an import,
// it will include the pragma in the results.
func GuessWithTreeSitter(ctx context.Context, dir string, lang *sitter.Language, queryImports string, searchGlobPatterns, ignoreGlobPatterns []string) ([]string, error) {
func GuessWithTreeSitter(ctx context.Context, root string, lang *sitter.Language, queryImports string, pathSegmentPatterns []string, ignorePathSegments map[string]bool) ([]string, error) {
//nolint:ineffassign,wastedassign,staticcheck
span, ctx := tracer.StartSpanFromContext(ctx, "GuessWithTreeSitter")
defer span.Finish()
dirFS := os.DirFS(dir)
dirFS := os.DirFS(root)

ignoredPaths := map[string]bool{}
for _, pattern := range ignoreGlobPatterns {
globIgnorePaths, err := fs.Glob(dirFS, pattern)
var visited int
pathsToSearch := []string{}
err := fs.WalkDir(dirFS, ".", func(dir string, d fs.DirEntry, err error) error {
dir = path.Join(root, dir)
if err != nil {
return nil, err
return err
}

for _, gPath := range globIgnorePaths {
ignoredPaths[gPath] = true
visited += 1

// Avoid locking up UPM on pathological project configurations
if visited > MaximumVisits {
return fs.SkipAll
}
}

pathsToSearch := []string{}
for _, pattern := range searchGlobPatterns {
globSearchPaths, err := fs.Glob(dirFS, pattern)
if err != nil {
return nil, err
if ignorePathSegments[d.Name()] {
return fs.SkipDir
}

for _, gPath := range globSearchPaths {
if !ignoredPaths[gPath] {
pathsToSearch = append(pathsToSearch, path.Join(dir, gPath))
for _, pattern := range pathSegmentPatterns {
var ok bool
if ok, err = path.Match(pattern, d.Name()); ok {
pathsToSearch = append(pathsToSearch, dir)
}
if err != nil {
return err
}
}

return nil
})
if err != nil {
return nil, err
}

query, err := sitter.NewQuery([]byte(queryImports), lang)
Expand Down
109 changes: 109 additions & 0 deletions internal/util/tree-sitter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package util

import (
"context"
"os"
"path"
"strings"
"testing"

"github.com/smacker/go-tree-sitter/python"
)

// importsQuery is the tree-sitter query that TestTreeSitter passes to
// GuessWithTreeSitter together with the Python language. It tags the dotted
// name of import statements, aliased imports and from-imports as @import, and
// an optional comment that follows the statement as @pragma.
var importsQuery = `
(module
[(import_statement
name: [(dotted_name) @import
(aliased_import
name: (dotted_name) @import)])
(import_from_statement
module_name: (dotted_name) @import)]
.
(comment)? @pragma)
`

// writeFile makes sure dir exists (creating any missing parents) and then
// writes contents to the file called name inside it. The first error hit by
// either step is returned unchanged.
func writeFile(dir, name string, contents []byte) error {
	if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil {
		return mkErr
	}

	target := path.Join(dir, name)
	return os.WriteFile(target, contents, 0644)
}

// TestTreeSitter checks that GuessWithTreeSitter descends into nested
// directories and leaves out ignored path segments. It lays out three Python
// files in a temporary directory: one at the root, one a few levels down, and
// one under "venv". It then asserts that exactly the imports of the first two
// are reported.
func TestTreeSitter(t *testing.T) {
	testDir := t.TempDir()

	// Imports that must be reported, each coming from a different depth.
	expected := map[string]bool{
		"from_root":  true,
		"from_inner": true,
	}

	fixtures := []struct {
		dir, name, contents string
	}{
		{testDir, "root.py", "import from_root"},
		{path.Join(testDir, "src", "my_module", "inner"), "inner.py", "import from_inner"},
		// Lives under "venv", which is ignored below, so from_venv must not
		// show up in the results.
		{path.Join(testDir, "venv", "ignored_module"), "ignored.py", "import from_venv"},
	}
	for _, f := range fixtures {
		if err := writeFile(f.dir, f.name, []byte(f.contents)); err != nil {
			// Without the fixtures the assertions below are meaningless, so
			// stop immediately instead of reporting misleading mismatches.
			t.Fatal(err)
		}
	}

	pathSegmentPatterns := []string{"*.py"}
	ignorePathSegments := map[string]bool{
		"venv": true,
	}

	ctx := context.Background()
	py := python.GetLanguage()
	found, err := GuessWithTreeSitter(ctx, testDir, py, importsQuery, pathSegmentPatterns, ignorePathSegments)
	if err != nil {
		t.Fatal(err)
	}

	// Deduplicate the results and collect anything that was not expected,
	// keeping the order in which GuessWithTreeSitter returned it.
	foundSet := make(map[string]bool, len(found))
	var extra []string
	for _, pkg := range found {
		if foundSet[pkg] {
			continue
		}
		foundSet[pkg] = true
		if !expected[pkg] {
			extra = append(extra, pkg)
		}
	}

	var missing []string
	for pkg := range expected {
		if !foundSet[pkg] {
			missing = append(missing, pkg)
		}
	}

	if len(missing) > 0 {
		t.Error("Not all expected checks were passed. Missing:", strings.Join(missing, ", "))
	}
	if len(extra) > 0 {
		t.Error("Not all expected checks were passed. Extra:", strings.Join(extra, ", "))
	}
}

0 comments on commit 46f4bc0

Please sign in to comment.