From fe47dca093cf6e92f3dc43f8a500f67427c7bf4a Mon Sep 17 00:00:00 2001 From: Toly Date: Wed, 26 Jun 2024 18:22:18 +0200 Subject: [PATCH] openapi3: implement circular reference backtracking (#970) * feat(loader): implement reference back-tracking * update docs * address review comments * update docs and readme * fix inconsistency * adjust readme --- .github/docs/openapi3.txt | 2 - README.md | 3 + cmd/validate/main.go | 9 - openapi3/issue570_test.go | 2 +- openapi3/issue615_test.go | 15 -- openapi3/issue796_test.go | 5 - ...oad_cicular_ref_with_external_file_test.go | 14 + openapi3/loader.go | 240 +++++++++--------- openapi3/testdata/circularRef/base.yml | 2 + openapi3/testdata/circularRef/baz.yml | 9 + 10 files changed, 149 insertions(+), 152 deletions(-) create mode 100644 openapi3/testdata/circularRef/baz.yml diff --git a/.github/docs/openapi3.txt b/.github/docs/openapi3.txt index 7c19ece4f..0ada195c8 100644 --- a/.github/docs/openapi3.txt +++ b/.github/docs/openapi3.txt @@ -55,8 +55,6 @@ var ( // ErrSchemaInputInf may be returned when validating a number ErrSchemaInputInf = errors.New("floating point Inf is not allowed") ) -var CircularReferenceCounter = 3 -var CircularReferenceError = "kin-openapi bug found: circular schema reference not handled" var DefaultReadFromURI = URIMapCache(ReadFromURIs(ReadFromHTTP(http.DefaultClient), ReadFromFile)) DefaultReadFromURI returns a caching ReadFromURIFunc which can read remote HTTP URIs and local file URIs. diff --git a/README.md b/README.md index cdf1f8570..d355d0ddd 100644 --- a/README.md +++ b/README.md @@ -303,6 +303,9 @@ for _, path := range doc.Paths.InMatchingOrder() { ## CHANGELOG: Sub-v1 breaking API changes +### v0.126.0 +* `openapi3.CircularReferenceError` and `openapi3.CircularReferenceCounter` are removed. `openapi3.Loader` now implements reference backtracking, so any kind of circular references should be properly resolved. + ### v0.125.0 * The `openapi3filter.ErrFunc` and `openapi3filter.LogFunc` func types now take the validated request's context as first argument. diff --git a/cmd/validate/main.go b/cmd/validate/main.go index a0580967b..874f30c26 100644 --- a/cmd/validate/main.go +++ b/cmd/validate/main.go @@ -12,11 +12,6 @@ import ( "github.com/getkin/kin-openapi/openapi3" ) -var ( - defaultCircular = openapi3.CircularReferenceCounter - circular = flag.Int("circular", defaultCircular, "bump this (upper) limit when there's trouble with cyclic schema references") -) - var ( defaultDefaults = true defaults = flag.Bool("defaults", defaultDefaults, "when false, disables schemas' default field validation") @@ -59,7 +54,6 @@ func main() { switch { case vd.OpenAPI == "3" || strings.HasPrefix(vd.OpenAPI, "3."): - openapi3.CircularReferenceCounter = *circular loader := openapi3.NewLoader() loader.IsExternalRefsAllowed = *ext @@ -90,9 +84,6 @@ func main() { case vd.OpenAPI == "2" || strings.HasPrefix(vd.OpenAPI, "2."), vd.Swagger == "2" || strings.HasPrefix(vd.Swagger, "2."): - if *circular != defaultCircular { - log.Fatal("Flag --circular is only for OpenAPIv3") - } if *defaults != defaultDefaults { log.Fatal("Flag --defaults is only for OpenAPIv3") } diff --git a/openapi3/issue570_test.go b/openapi3/issue570_test.go index 75afb7e3e..1575e5599 100644 --- a/openapi3/issue570_test.go +++ b/openapi3/issue570_test.go @@ -9,5 +9,5 @@ import ( func TestIssue570(t *testing.T) { loader := NewLoader() _, err := loader.LoadFromFile("testdata/issue570.json") - require.ErrorContains(t, err, CircularReferenceError) + require.NoError(t, err) } diff --git a/openapi3/issue615_test.go b/openapi3/issue615_test.go index 67144e9e1..02532bb5a 100644 --- a/openapi3/issue615_test.go +++ b/openapi3/issue615_test.go @@ -9,21 +9,6 @@ import ( ) func TestIssue615(t *testing.T) { - { - var old int - old, openapi3.CircularReferenceCounter = openapi3.CircularReferenceCounter, 1 - defer func() { openapi3.CircularReferenceCounter = old }() - - loader := openapi3.NewLoader() - loader.IsExternalRefsAllowed = true - _, err := loader.LoadFromFile("testdata/recursiveRef/issue615.yml") - require.ErrorContains(t, err, openapi3.CircularReferenceError) - } - - var old int - old, openapi3.CircularReferenceCounter = openapi3.CircularReferenceCounter, 4 - defer func() { openapi3.CircularReferenceCounter = old }() - loader := openapi3.NewLoader() loader.IsExternalRefsAllowed = true doc, err := loader.LoadFromFile("testdata/recursiveRef/issue615.yml") diff --git a/openapi3/issue796_test.go b/openapi3/issue796_test.go index 0900ee5b9..9c8be17f2 100644 --- a/openapi3/issue796_test.go +++ b/openapi3/issue796_test.go @@ -7,11 +7,6 @@ import ( ) func TestIssue796(t *testing.T) { - var old int - // Need to set CircularReferenceCounter to > 10 - old, CircularReferenceCounter = CircularReferenceCounter, 20 - defer func() { CircularReferenceCounter = old }() - loader := NewLoader() doc, err := loader.LoadFromFile("testdata/issue796.yml") require.NoError(t, err) diff --git a/openapi3/load_cicular_ref_with_external_file_test.go b/openapi3/load_cicular_ref_with_external_file_test.go index 7a99e7600..cff5c01fe 100644 --- a/openapi3/load_cicular_ref_with_external_file_test.go +++ b/openapi3/load_cicular_ref_with_external_file_test.go @@ -47,6 +47,19 @@ func TestLoadCircularRefFromFile(t *testing.T) { bar.Value.Properties["foo"] = &openapi3.SchemaRef{Ref: "#/components/schemas/Foo", Value: foo.Value} foo.Value.Properties["bar"] = &openapi3.SchemaRef{Ref: "#/components/schemas/Bar", Value: bar.Value} + bazNestedRef := &openapi3.SchemaRef{Ref: "./baz.yml#/BazNested"} + array := openapi3.NewArraySchema() + array.Items = bazNestedRef + bazNested := &openapi3.Schema{Properties: map[string]*openapi3.SchemaRef{ + "bazArray": { + Value: &openapi3.Schema{ + Items: bazNestedRef, + }, + }, + "baz": bazNestedRef, + }} + bazNestedRef.Value = bazNested + want := &openapi3.T{ OpenAPI: "3.0.3", Info: &openapi3.Info{ @@ -57,6 +70,7 @@ func TestLoadCircularRefFromFile(t *testing.T) { Schemas: openapi3.Schemas{ "Foo": foo, "Bar": bar, + "Baz": bazNestedRef, }, }, } diff --git a/openapi3/loader.go b/openapi3/loader.go index 452d02ef9..381b38bcd 100644 --- a/openapi3/loader.go +++ b/openapi3/loader.go @@ -16,9 +16,6 @@ import ( "strings" ) -var CircularReferenceError = "kin-openapi bug found: circular schema reference not handled" -var CircularReferenceCounter = 3 - func foundUnresolvedRef(ref string) error { return fmt.Errorf("found unresolved ref: %q", ref) } @@ -44,15 +41,9 @@ type Loader struct { visitedDocuments map[string]*T - visitedCallback map[*Callback]struct{} - visitedExample map[*Example]struct{} - visitedHeader map[*Header]struct{} - visitedLink map[*Link]struct{} - visitedParameter map[*Parameter]struct{} - visitedRequestBody map[*RequestBody]struct{} - visitedResponse map[*Response]struct{} - visitedSchema map[*Schema]struct{} - visitedSecurityScheme map[*SecurityScheme]struct{} + visitedRefs map[string]struct{} + visitedPath []string + backtrack map[string][]func(value any) } // NewLoader returns an empty Loader @@ -299,6 +290,34 @@ func isSingleRefElement(ref string) bool { return !strings.Contains(ref, "#") } +func (loader *Loader) visitRef(ref string) { + if loader.visitedRefs == nil { + loader.visitedRefs = make(map[string]struct{}) + loader.backtrack = make(map[string][]func(value any)) + } + loader.visitedPath = append(loader.visitedPath, ref) + loader.visitedRefs[ref] = struct{}{} +} + +func (loader *Loader) unvisitRef(ref string, value any) { + if value != nil { + for _, fn := range loader.backtrack[ref] { + fn(value) + } + } + delete(loader.visitedRefs, ref) + delete(loader.backtrack, ref) + loader.visitedPath = loader.visitedPath[:len(loader.visitedPath)-1] +} + +func (loader *Loader) shouldVisitRef(ref string, fn func(value any)) bool { + if _, ok := loader.visitedRefs[ref]; ok { + loader.backtrack[ref] = append(loader.backtrack[ref], fn) + return false + } + return true +} + func (loader *Loader) resolveComponent(doc *T, ref string, path *url.URL, resolved any) ( componentDoc *T, componentPath *url.URL, @@ -535,17 +554,10 @@ func (loader *Loader) resolveHeaderRef(doc *T, component *HeaderRef, documentPat return errMUSTHeader } - if component.Value != nil { - if loader.visitedHeader == nil { - loader.visitedHeader = make(map[*Header]struct{}) - } - if _, ok := loader.visitedHeader[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedHeader[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var header Header if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &header); err != nil { @@ -554,8 +566,15 @@ func (loader *Loader) resolveHeaderRef(doc *T, component *HeaderRef, documentPat component.Value = &header component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Header) + }) { + return nil + } var resolved HeaderRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -587,17 +606,10 @@ func (loader *Loader) resolveParameterRef(doc *T, component *ParameterRef, docum return errMUSTParameter } - if component.Value != nil { - if loader.visitedParameter == nil { - loader.visitedParameter = make(map[*Parameter]struct{}) - } - if _, ok := loader.visitedParameter[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedParameter[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var param Parameter if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, ¶m); err != nil { @@ -606,8 +618,15 @@ func (loader *Loader) resolveParameterRef(doc *T, component *ParameterRef, docum component.Value = ¶m component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Parameter) + }) { + return nil + } var resolved ParameterRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -649,17 +668,10 @@ func (loader *Loader) resolveRequestBodyRef(doc *T, component *RequestBodyRef, d return errMUSTRequestBody } - if component.Value != nil { - if loader.visitedRequestBody == nil { - loader.visitedRequestBody = make(map[*RequestBody]struct{}) - } - if _, ok := loader.visitedRequestBody[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedRequestBody[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var requestBody RequestBody if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &requestBody); err != nil { @@ -668,8 +680,15 @@ func (loader *Loader) resolveRequestBodyRef(doc *T, component *RequestBodyRef, d component.Value = &requestBody component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*RequestBody) + }) { + return nil + } var resolved RequestBodyRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -718,17 +737,10 @@ func (loader *Loader) resolveResponseRef(doc *T, component *ResponseRef, documen return errMUSTResponse } - if component.Value != nil { - if loader.visitedResponse == nil { - loader.visitedResponse = make(map[*Response]struct{}) - } - if _, ok := loader.visitedResponse[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedResponse[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var resp Response if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &resp); err != nil { @@ -737,8 +749,15 @@ func (loader *Loader) resolveResponseRef(doc *T, component *ResponseRef, documen component.Value = &resp component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Response) + }) { + return nil + } var resolved ResponseRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -798,17 +817,10 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat return errMUSTSchema } - if component.Value != nil { - if loader.visitedSchema == nil { - loader.visitedSchema = make(map[*Schema]struct{}) - } - if _, ok := loader.visitedSchema[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedSchema[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var schema Schema if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &schema); err != nil { @@ -817,14 +829,15 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat component.Value = &schema component.refPath = *documentPath } else { - if visitedLimit(visited, ref) { - visited = append(visited, ref) - return fmt.Errorf("%s with length %d - %s", CircularReferenceError, len(visited), strings.Join(visited, " -> ")) + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Schema) + }) { + return nil } - visited = append(visited, ref) - var resolved SchemaRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -837,10 +850,6 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat component.Value = resolved.Value component.refPath = resolved.refPath } - if loader.visitedSchema == nil { - loader.visitedSchema = make(map[*Schema]struct{}) - } - loader.visitedSchema[component.Value] = struct{}{} } value := component.Value if value == nil { @@ -891,17 +900,10 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme return errMUSTSecurityScheme } - if component.Value != nil { - if loader.visitedSecurityScheme == nil { - loader.visitedSecurityScheme = make(map[*SecurityScheme]struct{}) - } - if _, ok := loader.visitedSecurityScheme[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedSecurityScheme[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var scheme SecurityScheme if _, err = loader.loadSingleElementFromURI(ref, documentPath, &scheme); err != nil { @@ -910,8 +912,15 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme component.Value = &scheme component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*SecurityScheme) + }) { + return nil + } var resolved SecuritySchemeRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -929,21 +938,10 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme } func (loader *Loader) resolveExampleRef(doc *T, component *ExampleRef, documentPath *url.URL) (err error) { - if component.isEmpty() { - return errMUSTExample - } - - if component.Value != nil { - if loader.visitedExample == nil { - loader.visitedExample = make(map[*Example]struct{}) - } - if _, ok := loader.visitedExample[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedExample[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var example Example if _, err = loader.loadSingleElementFromURI(ref, documentPath, &example); err != nil { @@ -952,8 +950,15 @@ func (loader *Loader) resolveExampleRef(doc *T, component *ExampleRef, documentP component.Value = &example component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Example) + }) { + return nil + } var resolved ExampleRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -975,17 +980,10 @@ func (loader *Loader) resolveCallbackRef(doc *T, component *CallbackRef, documen return errMUSTCallback } - if component.Value != nil { - if loader.visitedCallback == nil { - loader.visitedCallback = make(map[*Callback]struct{}) - } - if _, ok := loader.visitedCallback[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedCallback[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var resolved Callback if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &resolved); err != nil { @@ -994,8 +992,15 @@ func (loader *Loader) resolveCallbackRef(doc *T, component *CallbackRef, documen component.Value = &resolved component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Callback) + }) { + return nil + } var resolved CallbackRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -1027,17 +1032,10 @@ func (loader *Loader) resolveLinkRef(doc *T, component *LinkRef, documentPath *u return errMUSTLink } - if component.Value != nil { - if loader.visitedLink == nil { - loader.visitedLink = make(map[*Link]struct{}) - } - if _, ok := loader.visitedLink[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedLink[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var link Link if _, err = loader.loadSingleElementFromURI(ref, documentPath, &link); err != nil { @@ -1046,8 +1044,15 @@ func (loader *Loader) resolveLinkRef(doc *T, component *LinkRef, documentPath *u component.Value = &link component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Link) + }) { + return nil + } var resolved LinkRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -1081,8 +1086,16 @@ func (loader *Loader) resolvePathItemRef(doc *T, pathItem *PathItem, documentPat } *pathItem = p } else { + if !loader.shouldVisitRef(ref, func(value any) { + *pathItem = *value.(*PathItem) + }) { + return nil + } var resolved PathItem - if doc, documentPath, err = loader.resolveComponent(doc, ref, documentPath, &resolved); err != nil { + loader.visitRef(ref) + doc, documentPath, err = loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, &resolved) + if err != nil { if err == errMUSTPathItem { return nil } @@ -1126,16 +1139,3 @@ func (loader *Loader) resolvePathItemRef(doc *T, pathItem *PathItem, documentPat func unescapeRefString(ref string) string { return strings.Replace(strings.Replace(ref, "~1", "/", -1), "~0", "~", -1) } - -func visitedLimit(visited []string, ref string) bool { - visitedCount := 0 - for _, v := range visited { - if v == ref { - visitedCount++ - if visitedCount >= CircularReferenceCounter { - return true - } - } - } - return false -} diff --git a/openapi3/testdata/circularRef/base.yml b/openapi3/testdata/circularRef/base.yml index ff8240eb0..897a45f37 100644 --- a/openapi3/testdata/circularRef/base.yml +++ b/openapi3/testdata/circularRef/base.yml @@ -14,3 +14,5 @@ components: properties: foo: $ref: "#/components/schemas/Foo" + Baz: + $ref: "./baz.yml#/BazNested" diff --git a/openapi3/testdata/circularRef/baz.yml b/openapi3/testdata/circularRef/baz.yml new file mode 100644 index 000000000..fb8c85420 --- /dev/null +++ b/openapi3/testdata/circularRef/baz.yml @@ -0,0 +1,9 @@ +BazNested: + type: object + properties: + baz: + $ref: "#/BazNested" + bazArray: + type: array + items: + $ref: "#/BazNested"