diff --git a/golist.go b/golist.go index b8663abde..ca916a8ea 100644 --- a/golist.go +++ b/golist.go @@ -56,7 +56,7 @@ func (parser *Parser) getAllGoFileInfoFromDepsByList(pkg *build.Package) error { srcDir := pkg.Dir var err error for i := range pkg.GoFiles { - err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.GoFiles[i]), nil) + err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.GoFiles[i]), nil, ParseModels) if err != nil { return err } @@ -64,7 +64,7 @@ func (parser *Parser) getAllGoFileInfoFromDepsByList(pkg *build.Package) error { // parse .go source files that import "C" for i := range pkg.CgoFiles { - err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.CgoFiles[i]), nil) + err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.CgoFiles[i]), nil, ParseModels) if err != nil { return err } diff --git a/packages.go b/packages.go index 9480acb17..ac1a184b2 100644 --- a/packages.go +++ b/packages.go @@ -33,18 +33,18 @@ func NewPackagesDefinitions() *PackagesDefinitions { } // ParseFile parse a source file. -func (pkgDefs *PackagesDefinitions) ParseFile(packageDir, path string, src interface{}) error { +func (pkgDefs *PackagesDefinitions) ParseFile(packageDir, path string, src interface{}, flag ParseFlag) error { // positions are relative to FileSet fileSet := token.NewFileSet() astFile, err := goparser.ParseFile(fileSet, path, src, goparser.ParseComments) if err != nil { return fmt.Errorf("failed to parse file %s, error:%+v", path, err) } - return pkgDefs.collectAstFile(fileSet, packageDir, path, astFile) + return pkgDefs.collectAstFile(fileSet, packageDir, path, astFile, flag) } // collectAstFile collect ast.file. -func (pkgDefs *PackagesDefinitions) collectAstFile(fileSet *token.FileSet, packageDir, path string, astFile *ast.File) error { +func (pkgDefs *PackagesDefinitions) collectAstFile(fileSet *token.FileSet, packageDir, path string, astFile *ast.File, flag ParseFlag) error { if pkgDefs.files == nil { pkgDefs.files = make(map[*ast.File]*AstFileInfo) } @@ -81,13 +81,14 @@ func (pkgDefs *PackagesDefinitions) collectAstFile(fileSet *token.FileSet, packa File: astFile, Path: path, PackagePath: packageDir, + ParseFlag: flag, } return nil } // RangeFiles for range the collection of ast.File in alphabetic order. -func (pkgDefs *PackagesDefinitions) RangeFiles(handle func(filename string, file *ast.File) error) error { +func (pkgDefs *PackagesDefinitions) RangeFiles(handle func(info *AstFileInfo) error) error { sortedFiles := make([]*AstFileInfo, 0, len(pkgDefs.files)) for _, info := range pkgDefs.files { // ignore package path prefix with 'vendor' or $GOROOT, @@ -103,7 +104,7 @@ func (pkgDefs *PackagesDefinitions) RangeFiles(handle func(filename string, file }) for _, info := range sortedFiles { - err := handle(info.Path, info.File) + err := handle(info) if err != nil { return err } diff --git a/packages_test.go b/packages_test.go index 60beca2f8..595659ba1 100644 --- a/packages_test.go +++ b/packages_test.go @@ -13,7 +13,7 @@ import ( func TestPackagesDefinitions_ParseFile(t *testing.T) { pd := PackagesDefinitions{} packageDir := "github.com/swaggo/swag/testdata/simple" - assert.NoError(t, pd.ParseFile(packageDir, "testdata/simple/main.go", nil)) + assert.NoError(t, pd.ParseFile(packageDir, "testdata/simple/main.go", nil, ParseAll)) assert.Equal(t, 1, len(pd.packages)) assert.Equal(t, 1, len(pd.files)) } @@ -21,14 +21,14 @@ func TestPackagesDefinitions_ParseFile(t *testing.T) { func TestPackagesDefinitions_collectAstFile(t *testing.T) { pd := PackagesDefinitions{} fileSet := token.NewFileSet() - assert.NoError(t, pd.collectAstFile(fileSet, "", "", nil)) + assert.NoError(t, pd.collectAstFile(fileSet, "", "", nil, ParseAll)) firstFile := &ast.File{ Name: &ast.Ident{Name: "main.go"}, } packageDir := "github.com/swaggo/swag/testdata/simple" - assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile)) + assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile, ParseAll)) assert.NotEmpty(t, pd.packages[packageDir]) absPath, _ := filepath.Abs("testdata/simple/" + firstFile.Name.String()) @@ -37,18 +37,19 @@ func TestPackagesDefinitions_collectAstFile(t *testing.T) { File: firstFile, Path: absPath, PackagePath: packageDir, + ParseFlag: ParseAll, } assert.Equal(t, pd.files[firstFile], astFileInfo) // Override - assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile)) + assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile, ParseAll)) assert.Equal(t, pd.files[firstFile], astFileInfo) // Another file secondFile := &ast.File{ Name: &ast.Ident{Name: "api.go"}, } - assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+secondFile.Name.String(), secondFile)) + assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+secondFile.Name.String(), secondFile, ParseAll)) } func TestPackagesDefinitions_rangeFiles(t *testing.T) { @@ -72,8 +73,8 @@ func TestPackagesDefinitions_rangeFiles(t *testing.T) { } i, expect := 0, []string{"testdata/simple/api/api.go", "testdata/simple/main.go"} - _ = pd.RangeFiles(func(filename string, file *ast.File) error { - assert.Equal(t, expect[i], filename) + _ = pd.RangeFiles(func(fileInfo *AstFileInfo) error { + assert.Equal(t, expect[i], fileInfo.Path) i++ return nil }) @@ -225,14 +226,14 @@ func TestPackage_rangeFiles(t *testing.T) { } var sorted []string - processor := func(filename string, file *ast.File) error { - sorted = append(sorted, filename) + processor := func(fileInfo *AstFileInfo) error { + sorted = append(sorted, fileInfo.Path) return nil } assert.NoError(t, pd.RangeFiles(processor)) assert.Equal(t, []string{"testdata/simple/api/api.go", "testdata/simple/main.go"}, sorted) - assert.Error(t, pd.RangeFiles(func(filename string, file *ast.File) error { + assert.Error(t, pd.RangeFiles(func(fileInfo *AstFileInfo) error { return ErrFuncTypeField })) diff --git a/parser.go b/parser.go index ebc0bf253..f3a7023d6 100644 --- a/parser.go +++ b/parser.go @@ -66,6 +66,20 @@ const ( scopeAttrPrefix = "@scope." ) +// ParseFlag determine what to parse +type ParseFlag int + +const ( + // ParseNone parse nothing + ParseNone ParseFlag = 0x00 + // ParseOperations parse operations + ParseOperations = 0x01 + // ParseModels parse models + ParseModels = 0x02 + // ParseAll parse operations and models + ParseAll = ParseOperations | ParseModels +) + var ( // ErrRecursiveParseStruct recursively parsing struct. ErrRecursiveParseStruct = errors.New("recursively parsing struct") @@ -866,8 +880,11 @@ func matchExtension(extensionToMatch string, comments []*ast.Comment) (match boo } // ParseRouterAPIInfo parses router api info for given astFile. -func (parser *Parser) ParseRouterAPIInfo(fileName string, astFile *ast.File) error { - for _, astDescription := range astFile.Decls { +func (parser *Parser) ParseRouterAPIInfo(fileInfo *AstFileInfo) error { + for _, astDescription := range fileInfo.File.Decls { + if (fileInfo.ParseFlag & ParseOperations) == ParseNone { + continue + } astDeclaration, ok := astDescription.(*ast.FuncDecl) if ok && astDeclaration.Doc != nil && astDeclaration.Doc.List != nil { if parser.matchTags(astDeclaration.Doc.List) && @@ -875,9 +892,9 @@ func (parser *Parser) ParseRouterAPIInfo(fileName string, astFile *ast.File) err // for per 'function' comment, create a new 'Operation' object operation := NewOperation(parser, SetCodeExampleFilesDirectory(parser.codeExampleFilesDir)) for _, comment := range astDeclaration.Doc.List { - err := operation.ParseComment(comment.Text, astFile) + err := operation.ParseComment(comment.Text, fileInfo.File) if err != nil { - return fmt.Errorf("ParseComment error in file %s :%+v", fileName, err) + return fmt.Errorf("ParseComment error in file %s :%+v", fileInfo.Path, err) } } err := processRouterOperation(parser, operation) @@ -1518,7 +1535,7 @@ func (parser *Parser) getAllGoFileInfo(packageDir, searchDir string) error { return err } - return parser.parseFile(filepath.ToSlash(filepath.Dir(filepath.Clean(filepath.Join(packageDir, relPath)))), path, nil) + return parser.parseFile(filepath.ToSlash(filepath.Dir(filepath.Clean(filepath.Join(packageDir, relPath)))), path, nil, ParseAll) }) } @@ -1546,7 +1563,7 @@ func (parser *Parser) getAllGoFileInfoFromDeps(pkg *depth.Pkg) error { } path := filepath.Join(srcDir, f.Name()) - if err := parser.parseFile(pkg.Name, path, nil); err != nil { + if err := parser.parseFile(pkg.Name, path, nil, ParseModels); err != nil { return err } } @@ -1560,12 +1577,12 @@ func (parser *Parser) getAllGoFileInfoFromDeps(pkg *depth.Pkg) error { return nil } -func (parser *Parser) parseFile(packageDir, path string, src interface{}) error { +func (parser *Parser) parseFile(packageDir, path string, src interface{}, flag ParseFlag) error { if strings.HasSuffix(strings.ToLower(path), "_test.go") || filepath.Ext(path) != ".go" { return nil } - return parser.packages.ParseFile(packageDir, path, src) + return parser.packages.ParseFile(packageDir, path, src, flag) } func (parser *Parser) checkOperationIDUniqueness() error { diff --git a/parser_test.go b/parser_test.go index d60ed8c85..041da8ac8 100644 --- a/parser_test.go +++ b/parser_test.go @@ -815,7 +815,7 @@ func Fun() { }` p := New() - _ = p.packages.ParseFile("api", "api/api.go", src) + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) _, err := p.packages.ParseTypes() assert.NoError(t, err) @@ -2373,7 +2373,7 @@ func Test(){ }` p := New() - _ = p.packages.ParseFile("api", "api/api.go", src) + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) _, err := p.packages.ParseTypes() assert.NoError(t, err) @@ -2433,9 +2433,9 @@ type ResponseWrapper struct { }` parser := New(SetParseDependency(true)) - _ = parser.packages.ParseFile("api", "api/api.go", src) + _ = parser.packages.ParseFile("api", "api/api.go", src, ParseAll) - _ = parser.packages.ParseFile("rest", "rest/rest.go", restsrc) + _ = parser.packages.ParseFile("rest", "rest/rest.go", restsrc, ParseAll) _, err := parser.packages.ParseTypes() assert.NoError(t, err) @@ -2498,7 +2498,7 @@ func Test(){ } }` p := New() - _ = p.packages.ParseFile("api", "api/api.go", src) + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) _, err := p.packages.ParseTypes() assert.NoError(t, err) @@ -2626,7 +2626,7 @@ func Test(){ } }` p := New() - _ = p.packages.ParseFile("api", "api/api.go", src) + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) _, err := p.packages.ParseTypes() assert.NoError(t, err) @@ -2649,12 +2649,12 @@ package test func Test(){ } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) - p := New() - err = p.ParseRouterAPIInfo("", f) - assert.EqualError(t, err, "ParseComment error in file :unknown accept type can't be accepted") + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.Error(t, err) } func TestParser_ParseRouterApiGet(t *testing.T) { @@ -2667,11 +2667,11 @@ package test func Test(){ } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) - p := New() - err = p.ParseRouterAPIInfo("", f) + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.NoError(t, err) ps := p.swagger.Paths.Paths @@ -2692,11 +2692,11 @@ package test func Test(){ } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) - p := New() - err = p.ParseRouterAPIInfo("", f) + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.NoError(t, err) ps := p.swagger.Paths.Paths @@ -2717,11 +2717,11 @@ package test func Test(){ } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) - assert.NoError(t, err) p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) - err = p.ParseRouterAPIInfo("", f) + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.NoError(t, err) ps := p.swagger.Paths.Paths @@ -2742,11 +2742,11 @@ package test func Test(){ } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) - p := New() - err = p.ParseRouterAPIInfo("", f) + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.NoError(t, err) ps := p.swagger.Paths.Paths @@ -2767,11 +2767,11 @@ package test func Test(){ } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) - p := New() - err = p.ParseRouterAPIInfo("", f) + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.NoError(t, err) ps := p.swagger.Paths.Paths @@ -2792,12 +2792,11 @@ package test func Test(){ } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) - assert.NoError(t, err) p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) - err = p.ParseRouterAPIInfo("", f) - + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.NoError(t, err) ps := p.swagger.Paths.Paths @@ -2817,11 +2816,11 @@ package test func Test(){ } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) - p := New() - err = p.ParseRouterAPIInfo("", f) + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.NoError(t, err) ps := p.swagger.Paths.Paths @@ -2843,11 +2842,11 @@ package test func Test(){ } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) - p := New() - err = p.ParseRouterAPIInfo("", f) + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.NoError(t, err) ps := p.swagger.Paths.Paths @@ -2881,11 +2880,11 @@ func Test2(){ func Test3(){ } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) - p := New() - err = p.ParseRouterAPIInfo("", f) + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.NoError(t, err) ps := p.swagger.Paths.Paths @@ -2946,15 +2945,18 @@ func FunctionTwo(w http.ResponseWriter, r *http.Request) { } ` - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + p := New(SetStrict(true)) + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) - p := New(SetStrict(true)) - err = p.ParseRouterAPIInfo("", f) + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.EqualError(t, err, "route GET /api/endpoint is declared multiple times") p = New() - err = p.ParseRouterAPIInfo("", f) + err = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) assert.NoError(t, err) } @@ -3127,7 +3129,7 @@ func Fun() { }` p := New() - err := p.packages.ParseFile("api", "api/api.go", src) + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) _, err = p.packages.ParseTypes() @@ -3182,7 +3184,7 @@ func Fun() { }` p := New() - _ = p.packages.ParseFile("api", "api/api.go", src) + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) _, err := p.packages.ParseTypes() assert.NoError(t, err) @@ -3218,7 +3220,7 @@ func Fun() { ` p := New() - _ = p.packages.ParseFile("api", "api/api.go", src) + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) _, err := p.packages.ParseTypes() assert.NoError(t, err) @@ -3260,7 +3262,7 @@ func Fun() { } ` p := New() - _ = p.packages.ParseFile("api", "api/api.go", src) + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) _, err := p.packages.ParseTypes() assert.NoError(t, err) @@ -3344,7 +3346,7 @@ func Fun() { }` p := New() - _ = p.packages.ParseFile("api", "api/api.go", src) + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) _, err := p.packages.ParseTypes() assert.NoError(t, err) @@ -3373,7 +3375,7 @@ func Fun() { pkgs.packages = nil pkgs.files = nil - _ = pkgs.ParseFile("api", "api/api.go", src) + _ = pkgs.ParseFile("api", "api/api.go", src, ParseAll) assert.NotNil(t, pkgs.packages) assert.NotNil(t, pkgs.files) } @@ -3391,7 +3393,7 @@ func Fun() { ` p := New() - _ = p.packages.ParseFile("api", "api/api.go", src) + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.Equal(t, 1, len(p.packages.files)) var path string var file *ast.File @@ -3402,7 +3404,7 @@ func Fun() { assert.NotNil(t, p.packages.files[file]) // if we collect the same again nothing should happen - _ = p.packages.ParseFile("api", "api/api.go", src) + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.Equal(t, 1, len(p.packages.files)) assert.Equal(t, file, p.packages.packages["api"].Files[path]) assert.NotNil(t, p.packages.files[file]) @@ -3534,7 +3536,7 @@ func Fun() { } ` p := New() - err := p.packages.ParseFile("api", "api/api.go", src) + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) assert.NoError(t, err) _, _ = p.packages.ParseTypes() err = p.packages.RangeFiles(p.ParseRouterAPIInfo) @@ -3880,11 +3882,13 @@ func TestParser_matchTags(t *testing.T) { } func TestParser_parseExtension(t *testing.T) { - - src, err := os.ReadFile("testdata/parseExtension/parseExtension.go") + packagePath := "testdata/parseExtension" + filePath := packagePath + "/parseExtension.go" + src, err := os.ReadFile(filePath) assert.NoError(t, err) - f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + fileSet := token.NewFileSet() + f, err := goparser.ParseFile(fileSet, "", src, goparser.ParseComments) assert.NoError(t, err) tests := []struct { @@ -3911,7 +3915,13 @@ func TestParser_parseExtension(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err = tt.parser.ParseRouterAPIInfo("", f) + err = tt.parser.ParseRouterAPIInfo(&AstFileInfo{ + FileSet: fileSet, + File: f, + Path: filePath, + PackagePath: packagePath, + ParseFlag: ParseAll, + }) assert.NoError(t, err) for p, isExpected := range tt.expectedPaths { _, ok := tt.parser.swagger.Paths.Paths[p] diff --git a/types.go b/types.go index 79c0f6a0d..8e2f51d87 100644 --- a/types.go +++ b/types.go @@ -92,4 +92,7 @@ type AstFileInfo struct { // PackagePath package import path of the ast.File PackagePath string + + // ParseFlag determine what to parse + ParseFlag ParseFlag }