Skip to content

Commit

Permalink
Improve the OAS package to handle recursive types properly
Browse files Browse the repository at this point in the history
  • Loading branch information
RussellLuo committed Sep 18, 2021
1 parent a9dd3e3 commit 586f5a6
Show file tree
Hide file tree
Showing 4 changed files with 492 additions and 344 deletions.
351 changes: 14 additions & 337 deletions pkg/oasv2/oasv2.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ import (
"strconv"
"strings"
"text/template"
"time"

"github.com/RussellLuo/kok/pkg/caseconv"
"github.com/RussellLuo/kok/pkg/codec/httpcodec"
)

var (
Expand Down Expand Up @@ -104,338 +100,12 @@ definitions:
`))
)

type JSONType struct {
Kind string
Type string
Format string
Description string
}

type ItemType JSONType

type Property struct {
Name string
Type JSONType
}

type Definition struct {
Type string
ItemTypeOrProperties interface{}
}

type OASResponse struct {
StatusCode int
SchemaName string
}

type OASResponses struct {
ContentTypes map[string]bool
Success OASResponse
Failures map[int]OASResponse
}

func AddDefinition(defs map[string]Definition, name string, value reflect.Value, anonymous ...bool) {
// NOTE: Use anonymous as a variadic argument for backwards compatibility.
embedded := false
if len(anonymous) > 0 {
embedded = anonymous[0]
}

if _, ok := defs[name]; ok {
// Ignore duplicated definitions implicitly.
return
}

switch value.Kind() {
case reflect.Struct:
addStructDefinition(defs, name, value, embedded)

case reflect.Map:
var properties []Property

valueType := value.Type()
if kind := valueType.Key().Kind(); kind != reflect.String && kind != reflect.Interface {
panic(fmt.Errorf(
"'%s' needs a map with string keys, has '%s' keys",
name, valueType.Key().Kind()))
}

for _, key := range value.MapKeys() {
keyString := key.String()
keyValue := addSubDefinition(defs, keyString, value.MapIndex(key), false)

properties = append(properties, Property{
Name: keyString,
Type: getJSONType(keyValue.Type(), caseconv.ToUpperCamelCase(keyString), ""),
})
}

defs[name] = Definition{
Type: "object",
ItemTypeOrProperties: properties,
}

case reflect.Slice, reflect.Array:
addArrayDefinition(defs, name, value, false)

case reflect.Ptr:
elemType := value.Type().Elem()
elem := reflect.New(elemType).Elem()
AddDefinition(defs, name, elem, embedded) // Always use the input name

default:
panic(fmt.Errorf("unsupported type %s", value.Kind()))
}
}

func addStructDefinition(defs map[string]Definition, name string, value reflect.Value, embedded bool) (properties []Property) {
if isTime(value) {
// Ignore this struct if it is a time value (of type `time.Time`).
return
}

structType := value.Type()
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
fieldName := field.Name
jsonTag := field.Tag.Get("json")
jsonName := strings.SplitN(jsonTag, ",", 2)[0]
if jsonName != "" {
if jsonName == "-" {
continue
}
fieldName = jsonName
}

var fieldValueType reflect.Type

kokField := httpcodec.GetKokField(field)
if kokField.Type != "" {
// Use the user-specified type (a basic type) if any.
var err error
if fieldValueType, err = getReflectType(kokField.Type); err != nil {
panic(err)
}
} else {
// Use the raw type of this struct field.
fieldValue := addSubDefinition(defs, fieldName, value.Field(i), field.Anonymous)
fieldValueType = fieldValue.Type()
}

if field.Anonymous {
// If this is an embedded field, promote the sub-properties of this field.

var subProperties []Property

ft := field.Type
switch k := ft.Kind(); {
case k == reflect.Struct:
v := value.Field(i)
subProperties = addStructDefinition(defs, "", v, field.Anonymous)
case k == reflect.Ptr && ft.Elem().Kind() == reflect.Struct:
v := reflect.New(ft.Elem()).Elem()
subProperties = addStructDefinition(defs, "", v, field.Anonymous)
}

properties = append(properties, subProperties...)
} else {
// Otherwise, append this field as a property.
properties = append(properties, Property{
Name: fieldName,
Type: getJSONType(fieldValueType, caseconv.ToUpperCamelCase(fieldName), kokField.Description),
})
}
}

// Only add non-embedded struct into definitions.
if !embedded {
defs[name] = Definition{
Type: "object",
ItemTypeOrProperties: properties,
}
}

return
}

func addSubDefinition(defs map[string]Definition, name string, value reflect.Value, embedded bool) reflect.Value {
typeName := value.Type().Name()
if typeName == "" {
typeName = caseconv.ToUpperCamelCase(name)
}

switch value.Kind() {
case reflect.Struct:
// We only need to call AddDefinition if this is a non-embedded struct.
// Otherwise, another call to addStructDefinition will be triggered
// instead within addStructDefinition.
if !embedded {
AddDefinition(defs, typeName, value, embedded)
}
case reflect.Map:
AddDefinition(defs, typeName, value, false)
case reflect.Slice, reflect.Array:
addArrayDefinition(defs, typeName, value, true)
case reflect.Ptr:
elemType := value.Type().Elem()
elemName := elemType.Name()
elem := reflect.New(elemType).Elem()
if !isBasicKind(elem.Kind()) {
// This is a pointer to a non-basic type, add more possible definitions.
AddDefinition(defs, elemName, elem, embedded)
}
case reflect.Interface:
value = addSubDefinition(defs, typeName, value.Elem(), embedded)
}

return value
}

func addArrayDefinition(defs map[string]Definition, name string, value reflect.Value, inner bool) {
elemType := value.Type().Elem()
k := elemType.Kind()

if isBasicKind(k) {
if !inner {
defs[name] = Definition{
Type: "array",
ItemTypeOrProperties: getJSONType(elemType, elemType.Name(), ""),
}
}
return
}

switch k {
case reflect.Struct, reflect.Map:
elem := reflect.New(elemType).Elem()
AddDefinition(defs, getArrayElemTypeName(elemType, name), elem)
case reflect.Ptr:
elemType = elemType.Elem()
for elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
elem := reflect.New(elemType).Elem()
AddDefinition(defs, getArrayElemTypeName(elemType, name), elem)
case reflect.Slice, reflect.Array:
elem := reflect.New(elemType).Elem()
addArrayDefinition(defs, getArrayElemTypeName(elemType, name), elem, inner)
default:
panic(fmt.Errorf("only struct slice or array is supported, but got %v", elemType.String()))
}

if !inner {
defs[name] = Definition{
Type: "array",
ItemTypeOrProperties: getArrayElemTypeName(elemType, name),
}
}
}

func getArrayElemTypeName(elemType reflect.Type, arrayTypeName string) string {
return getTypeName(elemType, arrayTypeName+"ArrayItem")
}

func isBasicKind(kind reflect.Kind) bool {
switch kind {
case reflect.Bool, reflect.String,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
return true
default:
return false
}
}

func getJSONType(typ reflect.Type, name, description string) JSONType {
switch typ.Kind() {
case reflect.Bool:
return JSONType{Kind: "basic", Type: "boolean", Description: description}
case reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Uint8, reflect.Uint16, reflect.Uint32:
return JSONType{Kind: "basic", Type: "integer", Format: "int32", Description: description}
case reflect.Int, reflect.Int64,
reflect.Uint, reflect.Uint64, reflect.Uintptr:
return JSONType{Kind: "basic", Type: "integer", Format: "int64", Description: description}
case reflect.Float32:
return JSONType{Kind: "basic", Type: "number", Format: "float", Description: description}
case reflect.Float64:
return JSONType{Kind: "basic", Type: "number", Format: "double", Description: description}
case reflect.String:
return JSONType{Kind: "basic", Type: "string", Description: description}
case reflect.Struct:
if isTime(reflect.New(typ).Elem()) {
// A time value is also a struct in Go, but it is represented as a string in OAS.
return JSONType{Kind: "basic", Type: "string", Format: "date-time", Description: description}
}
return JSONType{Kind: "object", Type: getTypeName(typ, name), Description: description}
case reflect.Map:
return JSONType{Kind: "object", Type: name, Description: description}
case reflect.Ptr:
// Dereference the pointer and get its element type.
return getJSONType(typ.Elem(), name, description)
case reflect.Slice, reflect.Array:
elemType := typ.Elem()
for elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
return JSONType{Kind: "array", Type: getArrayElemTypeName(elemType, name), Description: description}
default:
panic(fmt.Errorf("unsupported type %s", typ.Kind()))
}
}

func getTypeName(t reflect.Type, defaultName string) string {
if t.Name() != "" {
return t.Name()
func AddDefinition(defs map[string]Definition, name string, value reflect.Value) {
parser := NewParser()
parser.AddDefinition(name, value, false)
for name, def := range parser.Definitions() {
defs[name] = def
}
return defaultName
}

func isTime(v reflect.Value) bool {
switch v.Interface().(type) {
case *time.Time, time.Time:
return true
default:
return false
}
}

func getReflectType(typ string) (reflect.Type, error) {
var v interface{}
switch typ {
case "bool":
v = false
case "string":
v = ""
case "int":
v = int(0)
case "int8":
v = int8(0)
case "int16":
v = int16(0)
case "int32":
v = int32(0)
case "int64":
v = int64(0)
case "uint":
v = uint(0)
case "uint16":
v = uint16(0)
case "uint32":
v = uint32(0)
case "uint64":
v = uint64(0)
case "float32":
v = float32(0)
case "float64":
v = float64(0)
case "time":
v = time.Time{}
default:
return nil, fmt.Errorf("invalid basic type name: %s", typ)
}
return reflect.ValueOf(v).Type(), nil
}

func GetOASResponses(schema Schema, name string, statusCode int, body interface{}) OASResponses {
Expand Down Expand Up @@ -465,13 +135,20 @@ func GetOASResponses(schema Schema, name string, statusCode int, body interface{
}

func AddResponseDefinitions(defs map[string]Definition, schema Schema, name string, statusCode int, body interface{}) {
parser := NewParser()

success := schema.SuccessResponse(name, statusCode, body)
if success.Body != nil {
AddDefinition(defs, name+"Response", reflect.ValueOf(success.Body))
parser.AddDefinition(name+"Response", reflect.ValueOf(success.Body), false)
}

failures := schema.FailureResponses(name)
for _, failure := range failures {
AddDefinition(defs, name+"ResponseError"+strconv.Itoa(failure.StatusCode), reflect.ValueOf(failure.Body))
parser.AddDefinition(name+"ResponseError"+strconv.Itoa(failure.StatusCode), reflect.ValueOf(failure.Body), false)
}

for name, def := range parser.Definitions() {
defs[name] = def
}
}

Expand Down
Loading

0 comments on commit 586f5a6

Please sign in to comment.