Skip to content

Commit

Permalink
Add flag state #1628 (#1629)
Browse files Browse the repository at this point in the history
* add state flag
  • Loading branch information
ivolkoff authored Dec 18, 2023
1 parent 744a58e commit 7603121
Show file tree
Hide file tree
Showing 12 changed files with 1,186 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ OPTIONS:
--tags value, -t value A comma-separated list of tags to filter the APIs for which the documentation is generated.Special case if the tag is prefixed with the '!' character then the APIs with that tag will be excluded
--templateDelims value, --td value Provide custom delimeters for Go template generation. The format is leftDelim,rightDelim. For example: "[[,]]"
--collectionFormat value, --cf value Set default collection format (default: "csv")
--state value Initial state for the state machine (default: ""), @HostState in root file, @State in other files
--help, -h show help (default: false)
```

Expand Down
7 changes: 7 additions & 0 deletions cmd/swag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ const (
packageName = "packageName"
collectionFormatFlag = "collectionFormat"
packagePrefixFlag = "packagePrefix"
stateFlag = "state"
)

var initFlags = []cli.Flag{
Expand Down Expand Up @@ -173,6 +174,11 @@ var initFlags = []cli.Flag{
Value: "",
Usage: "Parse only packages whose import path match the given prefix, comma separated",
},
&cli.StringFlag{
Name: stateFlag,
Value: "",
Usage: "Set host state for swagger.json",
},
}

func initAction(ctx *cli.Context) error {
Expand Down Expand Up @@ -242,6 +248,7 @@ func initAction(ctx *cli.Context) error {
Debugger: logger,
CollectionFormat: collectionFormat,
PackagePrefix: ctx.String(packagePrefixFlag),
State: ctx.String(stateFlag),
})
}

Expand Down
35 changes: 30 additions & 5 deletions gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (

"github.com/go-openapi/spec"
"github.com/swaggo/swag"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"sigs.k8s.io/yaml"
)

Expand Down Expand Up @@ -141,6 +143,9 @@ type Config struct {

// Parse only packages whose import path match the given prefix, comma separated
PackagePrefix string

// State set host state
State string
}

// Build builds swagger json file for given searchDir and mainAPIFile. Returns json.
Expand Down Expand Up @@ -207,6 +212,7 @@ func (g *Gen) Build(config *Config) error {
p.ParseVendor = config.ParseVendor
p.ParseInternal = config.ParseInternal
p.RequiredByDefault = config.RequiredByDefault
p.HostState = config.State

if err := p.ParseAPIMultiSearchDir(searchDirs, config.MainAPIFile, config.ParseDepth); err != nil {
return err
Expand Down Expand Up @@ -235,6 +241,10 @@ func (g *Gen) Build(config *Config) error {
func (g *Gen) writeDocSwagger(config *Config, swagger *spec.Swagger) error {
var filename = "docs.go"

if config.State != "" {
filename = config.State + "_" + filename
}

if config.InstanceName != swag.Name {
filename = config.InstanceName + "_" + filename
}
Expand Down Expand Up @@ -274,6 +284,10 @@ func (g *Gen) writeDocSwagger(config *Config, swagger *spec.Swagger) error {
func (g *Gen) writeJSONSwagger(config *Config, swagger *spec.Swagger) error {
var filename = "swagger.json"

if config.State != "" {
filename = config.State + "_" + filename
}

if config.InstanceName != swag.Name {
filename = config.InstanceName + "_" + filename
}
Expand All @@ -298,6 +312,10 @@ func (g *Gen) writeJSONSwagger(config *Config, swagger *spec.Swagger) error {
func (g *Gen) writeYAMLSwagger(config *Config, swagger *spec.Swagger) error {
var filename = "swagger.yaml"

if config.State != "" {
filename = config.State + "_" + filename
}

if config.InstanceName != swag.Name {
filename = config.InstanceName + "_" + filename
}
Expand Down Expand Up @@ -441,6 +459,11 @@ func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swa
return err
}

state := ""
if len(config.State) > 0 {
state = cases.Title(language.English).String(strings.ToLower(config.State))
}

buffer := &bytes.Buffer{}

err = generator.Execute(buffer, struct {
Expand All @@ -452,6 +475,7 @@ func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swa
Title string
Description string
Version string
State string
InstanceName string
Schemes []string
GeneratedTime bool
Expand All @@ -468,6 +492,7 @@ func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swa
Title: swagger.Info.Title,
Description: swagger.Info.Description,
Version: swagger.Info.Version,
State: state,
InstanceName: config.InstanceName,
LeftTemplateDelim: config.LeftTemplateDelim,
RightTemplateDelim: config.RightTemplateDelim,
Expand All @@ -489,23 +514,23 @@ package {{.PackageName}}
import "github.com/swaggo/swag"
const docTemplate{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} = ` + "`{{ printDoc .Doc}}`" + `
const docTemplate{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }}{{ .State }} = ` + "`{{ printDoc .Doc}}`" + `
// SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} holds exported Swagger Info so clients can modify it
var SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} = &swag.Spec{
// Swagger{{ .State }}Info{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} holds exported Swagger Info so clients can modify it
var Swagger{{ .State }}Info{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} = &swag.Spec{
Version: {{ printf "%q" .Version}},
Host: {{ printf "%q" .Host}},
BasePath: {{ printf "%q" .BasePath}},
Schemes: []string{ {{ range $index, $schema := .Schemes}}{{if gt $index 0}},{{end}}{{printf "%q" $schema}}{{end}} },
Title: {{ printf "%q" .Title}},
Description: {{ printf "%q" .Description}},
InfoInstanceName: {{ printf "%q" .InstanceName }},
SwaggerTemplate: docTemplate{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }},
SwaggerTemplate: docTemplate{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }}{{ .State }},
LeftDelim: {{ printf "%q" .LeftTemplateDelim}},
RightDelim: {{ printf "%q" .RightTemplateDelim}},
}
func init() {
swag.Register(SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }}.InstanceName(), SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }})
swag.Register(Swagger{{ .State }}Info{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }}.InstanceName(), Swagger{{ .State }}Info{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }})
}
`
78 changes: 78 additions & 0 deletions gen/gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -895,3 +895,81 @@ func TestGen_ErrorAndInterface(t *testing.T) {

assert.JSONEq(t, string(expectedJSON), string(jsonOutput))
}

func TestGen_StateAdmin(t *testing.T) {
config := &Config{
SearchDir: "../testdata/state",
MainAPIFile: "./main.go",
OutputDir: "../testdata/state/docs",
OutputTypes: outputTypes,
PropNamingStrategy: "",
State: "admin",
}

assert.NoError(t, New().Build(config))

expectedFiles := []string{
filepath.Join(config.OutputDir, "admin_docs.go"),
filepath.Join(config.OutputDir, "admin_swagger.json"),
filepath.Join(config.OutputDir, "admin_swagger.yaml"),
}
t.Cleanup(func() {
for _, expectedFile := range expectedFiles {
_ = os.Remove(expectedFile)
}
})

// check files
for _, expectedFile := range expectedFiles {
if _, err := os.Stat(expectedFile); os.IsNotExist(err) {
require.NoError(t, err)
}
}

// check content
jsonOutput, err := os.ReadFile(filepath.Join(config.OutputDir, "admin_swagger.json"))
require.NoError(t, err)
expectedJSON, err := os.ReadFile(filepath.Join(config.SearchDir, "admin_expected.json"))
require.NoError(t, err)

assert.JSONEq(t, string(expectedJSON), string(jsonOutput))
}

func TestGen_StateUser(t *testing.T) {
config := &Config{
SearchDir: "../testdata/state",
MainAPIFile: "./main.go",
OutputDir: "../testdata/state/docs",
OutputTypes: outputTypes,
PropNamingStrategy: "",
State: "user",
}

assert.NoError(t, New().Build(config))

expectedFiles := []string{
filepath.Join(config.OutputDir, "user_docs.go"),
filepath.Join(config.OutputDir, "user_swagger.json"),
filepath.Join(config.OutputDir, "user_swagger.yaml"),
}
t.Cleanup(func() {
for _, expectedFile := range expectedFiles {
_ = os.Remove(expectedFile)
}
})

// check files
for _, expectedFile := range expectedFiles {
if _, err := os.Stat(expectedFile); os.IsNotExist(err) {
require.NoError(t, err)
}
}

// check content
jsonOutput, err := os.ReadFile(filepath.Join(config.OutputDir, "user_swagger.json"))
require.NoError(t, err)
expectedJSON, err := os.ReadFile(filepath.Join(config.SearchDir, "user_expected.json"))
require.NoError(t, err)

assert.JSONEq(t, string(expectedJSON), string(jsonOutput))
}
8 changes: 8 additions & 0 deletions operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Operation struct {
codeExampleFilesDir string
spec.Operation
RouterProperties []RouteProperties
State string
}

var mimeTypeAliases = map[string]string{
Expand Down Expand Up @@ -118,6 +119,8 @@ func (operation *Operation) ParseComment(comment string, astFile *ast.File) erro
lineRemainder = fields[1]
}
switch lowerAttribute {
case stateAttr:
operation.ParseStateComment(lineRemainder)
case descriptionAttr:
operation.ParseDescriptionComment(lineRemainder)
case descriptionMarkdownAttr:
Expand Down Expand Up @@ -183,6 +186,11 @@ func (operation *Operation) ParseCodeSample(attribute, _, lineRemainder string)
return operation.ParseMetadata(attribute, strings.ToLower(attribute), lineRemainder)
}

// ParseDescriptionComment godoc.

Check failure on line 189 in operation.go

View workflow job for this annotation

GitHub Actions / test (1.18.x, ubuntu-latest)

comment on exported method Operation.ParseStateComment should be of the form "ParseStateComment ..."

Check failure on line 189 in operation.go

View workflow job for this annotation

GitHub Actions / test (1.19.x, ubuntu-latest)

comment on exported method Operation.ParseStateComment should be of the form "ParseStateComment ..."

Check failure on line 189 in operation.go

View workflow job for this annotation

GitHub Actions / test (1.20.x, ubuntu-latest)

comment on exported method Operation.ParseStateComment should be of the form "ParseStateComment ..."

Check failure on line 189 in operation.go

View workflow job for this annotation

GitHub Actions / test (1.21.x, ubuntu-latest)

comment on exported method Operation.ParseStateComment should be of the form "ParseStateComment ..."
func (operation *Operation) ParseStateComment(lineRemainder string) {
operation.State = lineRemainder
}

// ParseDescriptionComment godoc.
func (operation *Operation) ParseDescriptionComment(lineRemainder string) {
if operation.Description == "" {
Expand Down
16 changes: 16 additions & 0 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ const (
extDocsURLAttr = "@externaldocs.url"
xCodeSamplesAttr = "@x-codesamples"
scopeAttrPrefix = "@scope."
stateAttr = "@state"
)

// ParseFlag determine what to parse
Expand Down Expand Up @@ -174,6 +175,9 @@ type Parser struct {

// tags to filter the APIs after
tags map[string]struct{}

// HostState is the state of the host
HostState string
}

// FieldParserFactory create FieldParser.
Expand Down Expand Up @@ -541,6 +545,14 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error {

case "@host":
parser.swagger.Host = value
case "@hoststate":
fields = FieldsByAnySpace(commentLine, 3)
if len(fields) != 3 {
return fmt.Errorf("%s needs 3 arguments", attribute)
}
if parser.HostState == fields[1] {
parser.swagger.Host = fields[2]
}
case "@basepath":
parser.swagger.BasePath = value

Expand Down Expand Up @@ -977,6 +989,7 @@ func matchExtension(extensionToMatch string, comments []*ast.Comment) (match boo

// ParseRouterAPIInfo parses router api info for given astFile.
func (parser *Parser) ParseRouterAPIInfo(fileInfo *AstFileInfo) error {
DeclsLoop:
for _, astDescription := range fileInfo.File.Decls {
if (fileInfo.ParseFlag & ParseOperations) == ParseNone {
continue
Expand All @@ -992,6 +1005,9 @@ func (parser *Parser) ParseRouterAPIInfo(fileInfo *AstFileInfo) error {
if err != nil {
return fmt.Errorf("ParseComment error in file %s :%+v", fileInfo.Path, err)
}
if operation.State != "" && operation.State != parser.HostState {
continue DeclsLoop
}
}
err := processRouterOperation(parser, operation)
if err != nil {
Expand Down
Loading

0 comments on commit 7603121

Please sign in to comment.