Skip to content

Commit

Permalink
combine swagger file to one and fix CLA
Browse files Browse the repository at this point in the history
  • Loading branch information
cy-zheng authored and achew22 committed May 23, 2018
1 parent 7d77d95 commit e7b79eb
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 72 deletions.
13 changes: 13 additions & 0 deletions protoc-gen-grpc-gateway/descriptor/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ type Registry struct {

// allowMerge generation one swagger file out of multiple protos
allowMerge bool

// mergeFileName target swagger file name after merge
mergeFileName string
}

// NewRegistry returns a new Registry.
Expand Down Expand Up @@ -310,6 +313,16 @@ func (r *Registry) IsAllowMerge() bool {
return r.allowMerge
}

// SetMergeFileName controls the target swagger file name out of multiple protos
func (r *Registry) SetMergeFileName(mergeFileName string) {
r.mergeFileName = mergeFileName
}

// GetMergeFileName return the target merge swagger file name
func (r *Registry) GetMergeFileName() string {
return r.mergeFileName
}

// sanitizePackageName replaces unallowed character in package name
// with allowed character.
func sanitizePackageName(pkgName string) string {
Expand Down
75 changes: 62 additions & 13 deletions protoc-gen-swagger/genswagger/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,57 @@ type generator struct {
reg *descriptor.Registry
}

type wrapper struct {
fileName string
swagger *swaggerObject
}

// New returns a new generator which generates grpc gateway files.
func New(reg *descriptor.Registry) gen.Generator {
return &generator{reg: reg}
}

// Merge a lot of swagger file (wrapper) to single one swagger file
func mergeTargetFile(targets []*wrapper, mergeFileName string) *wrapper {
var mergedTarget *wrapper
for _, f := range targets {
if mergedTarget == nil {
mergedTarget = &wrapper{
fileName: mergeFileName,
swagger: f.swagger,
}
} else {
for k, v := range f.swagger.Definitions {
mergedTarget.swagger.Definitions[k] = v
}
for k, v := range f.swagger.Paths {
mergedTarget.swagger.Paths[k] = v
}
for k, v := range f.swagger.SecurityDefinitions {
mergedTarget.swagger.SecurityDefinitions[k] = v
}
mergedTarget.swagger.Security = append(mergedTarget.swagger.Security, f.swagger.Security...)
}
}
return mergedTarget
}

// convert swagger file obj to plugin.CodeGeneratorResponse_File
func encodeSwagger(file *wrapper) *plugin.CodeGeneratorResponse_File {
var formatted bytes.Buffer
enc := json.NewEncoder(&formatted)
enc.SetIndent("", " ")
enc.Encode(*file.swagger)
name := file.fileName
ext := filepath.Ext(name)
base := strings.TrimSuffix(name, ext)
output := fmt.Sprintf("%s.swagger.json", base)
return &plugin.CodeGeneratorResponse_File{
Name: proto.String(output),
Content: proto.String(formatted.String()),
}
}

func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) {
var files []*plugin.CodeGeneratorResponse_File
if g.reg.IsAllowMerge() {
Expand All @@ -55,29 +101,32 @@ func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGenerato
targets = append(targets, mergedTarget)
}

var swaggers []*wrapper
for _, file := range targets {
glog.V(1).Infof("Processing %s", file.GetName())
code, err := applyTemplate(param{File: file, reg: g.reg})
swagger, err := applyTemplate(param{File: file, reg: g.reg})
if err == errNoTargetService {
glog.V(1).Infof("%s: %v", file.GetName(), err)
continue
}
if err != nil {
return nil, err
}

var formatted bytes.Buffer
json.Indent(&formatted, []byte(code), "", " ")

name := file.GetName()
ext := filepath.Ext(name)
base := strings.TrimSuffix(name, ext)
output := fmt.Sprintf("%s.swagger.json", base)
files = append(files, &plugin.CodeGeneratorResponse_File{
Name: proto.String(output),
Content: proto.String(formatted.String()),
swaggers = append(swaggers, &wrapper{
fileName: file.GetName(),
swagger: swagger,
})
glog.V(1).Infof("Will emit %s", output)
}

if g.reg.IsAllowMerge() {
targetSwagger := mergeTargetFile(swaggers, g.reg.GetMergeFileName())
files = append(files, encodeSwagger(targetSwagger))
glog.V(1).Infof("New swagger file will emit")
} else {
for _, file := range swaggers {
files = append(files, encodeSwagger(file))
glog.V(1).Infof("New swagger file will emit")
}
}
return files, nil
}
12 changes: 2 additions & 10 deletions protoc-gen-swagger/genswagger/template.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package genswagger

import (
"bytes"
"encoding/json"
"fmt"
"os"
"reflect"
Expand Down Expand Up @@ -709,7 +707,7 @@ func renderServices(services []*descriptor.Service, paths swaggerPathsObject, re
}

// This function is called with a param which contains the entire definition of a method.
func applyTemplate(p param) (string, error) {
func applyTemplate(p param) (*swaggerObject, error) {
// Create the basic template object. This is the object that everything is
// defined off of.
s := swaggerObject{
Expand Down Expand Up @@ -903,13 +901,7 @@ func applyTemplate(p param) (string, error) {
// should be added here, once supported in the proto.
}

// We now have rendered the entire swagger object. Write the bytes out to a
// string so it can be written to disk.
var w bytes.Buffer
enc := json.NewEncoder(&w)
enc.Encode(&s)

return w.String(), nil
return &s, nil
}

// updateSwaggerDataFromComments updates a Swagger object based on a comment
Expand Down
50 changes: 16 additions & 34 deletions protoc-gen-swagger/genswagger/template_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package genswagger

import (
"encoding/json"
"reflect"
"testing"

Expand All @@ -10,6 +9,7 @@ import (
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
"fmt"
)

func crossLinkFixture(f *descriptor.File) *descriptor.File {
Expand Down Expand Up @@ -253,32 +253,26 @@ func TestApplyTemplateSimple(t *testing.T) {
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
return
}
got := new(swaggerObject)
err = json.Unmarshal([]byte(result), got)
if err != nil {
t.Errorf("json.Unmarshal(%s) failed with %v; want success", result, err)
return
}
if want, is, name := "2.0", got.Swagger, "Swagger"; !reflect.DeepEqual(is, want) {
if want, is, name := "2.0", result.Swagger, "Swagger"; !reflect.DeepEqual(is, want) {
t.Errorf("applyTemplate(%#v).%s = %s want to be %s", file, name, is, want)
}
if want, is, name := "", got.BasePath, "BasePath"; !reflect.DeepEqual(is, want) {
if want, is, name := "", result.BasePath, "BasePath"; !reflect.DeepEqual(is, want) {
t.Errorf("applyTemplate(%#v).%s = %s want to be %s", file, name, is, want)
}
if want, is, name := []string{"http", "https"}, got.Schemes, "Schemes"; !reflect.DeepEqual(is, want) {
if want, is, name := []string{"http", "https"}, result.Schemes, "Schemes"; !reflect.DeepEqual(is, want) {
t.Errorf("applyTemplate(%#v).%s = %s want to be %s", file, name, is, want)
}
if want, is, name := []string{"application/json"}, got.Consumes, "Consumes"; !reflect.DeepEqual(is, want) {
if want, is, name := []string{"application/json"}, result.Consumes, "Consumes"; !reflect.DeepEqual(is, want) {
t.Errorf("applyTemplate(%#v).%s = %s want to be %s", file, name, is, want)
}
if want, is, name := []string{"application/json"}, got.Produces, "Produces"; !reflect.DeepEqual(is, want) {
if want, is, name := []string{"application/json"}, result.Produces, "Produces"; !reflect.DeepEqual(is, want) {
t.Errorf("applyTemplate(%#v).%s = %s want to be %s", file, name, is, want)
}

// If there was a failure, print out the input and the json result for debugging.
if t.Failed() {
t.Errorf("had: %s", file)
t.Errorf("got: %s", result)
t.Errorf("got: %s", fmt.Sprint(result))
}
}

Expand Down Expand Up @@ -413,35 +407,29 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) {
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
return
}
var obj swaggerObject
err = json.Unmarshal([]byte(result), &obj)
if err != nil {
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
return
}
if want, got := "2.0", obj.Swagger; !reflect.DeepEqual(got, want) {
if want, got := "2.0", result.Swagger; !reflect.DeepEqual(got, want) {
t.Errorf("applyTemplate(%#v).Swagger = %s want to be %s", file, got, want)
}
if want, got := "", obj.BasePath; !reflect.DeepEqual(got, want) {
if want, got := "", result.BasePath; !reflect.DeepEqual(got, want) {
t.Errorf("applyTemplate(%#v).BasePath = %s want to be %s", file, got, want)
}
if want, got := []string{"http", "https"}, obj.Schemes; !reflect.DeepEqual(got, want) {
if want, got := []string{"http", "https"}, result.Schemes; !reflect.DeepEqual(got, want) {
t.Errorf("applyTemplate(%#v).Schemes = %s want to be %s", file, got, want)
}
if want, got := []string{"application/json"}, obj.Consumes; !reflect.DeepEqual(got, want) {
if want, got := []string{"application/json"}, result.Consumes; !reflect.DeepEqual(got, want) {
t.Errorf("applyTemplate(%#v).Consumes = %s want to be %s", file, got, want)
}
if want, got := []string{"application/json"}, obj.Produces; !reflect.DeepEqual(got, want) {
if want, got := []string{"application/json"}, result.Produces; !reflect.DeepEqual(got, want) {
t.Errorf("applyTemplate(%#v).Produces = %s want to be %s", file, got, want)
}
if want, got, name := "Generated for ExampleService.Echo - ", obj.Paths["/v1/echo"].Post.Summary, "Paths[/v1/echo].Post.Summary"; !reflect.DeepEqual(got, want) {
if want, got, name := "Generated for ExampleService.Echo - ", result.Paths["/v1/echo"].Post.Summary, "Paths[/v1/echo].Post.Summary"; !reflect.DeepEqual(got, want) {
t.Errorf("applyTemplate(%#v).%s = %s want to be %s", file, name, got, want)
}

// If there was a failure, print out the input and the json result for debugging.
if t.Failed() {
t.Errorf("had: %s", file)
t.Errorf("got: %s", result)
t.Errorf("got: %s", fmt.Sprint(result))
}
}

Expand Down Expand Up @@ -685,22 +673,16 @@ func TestApplyTemplateRequestWithUnusedReferences(t *testing.T) {
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
return
}
var obj swaggerObject
err = json.Unmarshal([]byte(result), &obj)
if err != nil {
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
return
}

// Only EmptyMessage must be present, not ExampleMessage
if want, got, name := 1, len(obj.Definitions), "len(Definitions)"; !reflect.DeepEqual(got, want) {
if want, got, name := 1, len(result.Definitions), "len(Definitions)"; !reflect.DeepEqual(got, want) {
t.Errorf("applyTemplate(%#v).%s = %d want to be %d", file, name, got, want)
}

// If there was a failure, print out the input and the json result for debugging.
if t.Failed() {
t.Errorf("had: %s", file)
t.Errorf("got: %s", result)
t.Errorf("got: %s", fmt.Sprint(result))
}
}

Expand Down
10 changes: 9 additions & 1 deletion protoc-gen-swagger/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ var (
file = flag.String("file", "-", "where to load data from")
allowDeleteBody = flag.Bool("allow_delete_body", false, "unless set, HTTP DELETE methods may not have a body")
grpcAPIConfiguration = flag.String("grpc_api_configuration", "", "path to gRPC API Configuration in YAML format")
allowDeleteBody = flag.Bool("allow_delete_body", false, "unless set, HTTP DELETE methods may not have a body")
allowMerge = flag.Bool("allow_merge", false, "if set, generation one swagger file out of multiple protos")
mergeFileName = flag.String("merge_file_name", "apidocs", "target swagger file name prefix after merge")
)

func main() {
Expand Down Expand Up @@ -55,6 +55,7 @@ func main() {
reg.SetPrefix(*importPrefix)
reg.SetAllowDeleteBody(*allowDeleteBody)
reg.SetAllowMerge(*allowMerge)
reg.SetMergeFileName(*mergeFileName)
for k, v := range pkgMap {
reg.AddPkgMap(k, v)
}
Expand Down Expand Up @@ -126,6 +127,13 @@ func parseReqParam(param string, f *flag.FlagSet, pkgMap map[string]string) erro
}
continue
}
if spec[0] == "allow_merge" {
err := f.Set(spec[0], "true")
if err != nil {
return fmt.Errorf("Cannot set flag %s: %v", p, err)
}
continue
}
err := f.Set(spec[0], "")
if err != nil {
return fmt.Errorf("Cannot set flag %s: %v", p, err)
Expand Down
Loading

0 comments on commit e7b79eb

Please sign in to comment.