Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix validation of requests with multipart/form-data schema containing allOf #3

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions openapi3filter/issue722_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package openapi3filter_test

import (
"bytes"
"context"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"strings"
"testing"

"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/getkin/kin-openapi/routers/gorillamux"
)

func TestValidateMultipartFormDataContainingAllOf(t *testing.T) {
const spec = `
openapi: 3.0.0
info:
title: 'Validator'
version: 0.0.1
paths:
/test:
post:
requestBody:
required: true
content:
multipart/form-data:
schema:
type: object
required:
- file
allOf:
- $ref: '#/components/schemas/Category'
- properties:
file:
type: string
format: binary
description:
type: string
responses:
'200':
description: Created

components:
schemas:
Category:
type: object
properties:
name:
type: string
required:
- name
`

loader := openapi3.NewLoader()
doc, err := loader.LoadFromData([]byte(spec))
if err != nil {
t.Fatal(err)
}
if err = doc.Validate(loader.Context); err != nil {
t.Fatal(err)
}

router, err := gorillamux.NewRouter(doc)
if err != nil {
t.Fatal(err)
}

body := &bytes.Buffer{}
writer := multipart.NewWriter(body)

{ // Add file data
fw, err := writer.CreateFormFile("file", "hello.txt")
if err != nil {
t.Fatal(err)
}
if _, err = io.Copy(fw, strings.NewReader("hello")); err != nil {
t.Fatal(err)
}
}

{ // Add a single "name" item as part data
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", `form-data; name="name"`)
fw, err := writer.CreatePart(h)
if err != nil {
t.Fatal(err)
}
if _, err = io.Copy(fw, strings.NewReader(`foo`)); err != nil {
t.Fatal(err)
}
}

{ // Add a single "discription" item as part data
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", `form-data; name="description"`)
fw, err := writer.CreatePart(h)
if err != nil {
t.Fatal(err)
}
if _, err = io.Copy(fw, strings.NewReader(`description note`)); err != nil {
t.Fatal(err)
}
}

writer.Close()

req, err := http.NewRequest(http.MethodPost, "/test", bytes.NewReader(body.Bytes()))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())

route, pathParams, err := router.FindRoute(req)
if err != nil {
t.Fatal(err)
}

if err = openapi3filter.ValidateRequestBody(
context.Background(),
&openapi3filter.RequestValidationInput{
Request: req,
PathParams: pathParams,
Route: route,
},
route.Operation.RequestBody.Value,
); err != nil {
t.Error(err)
}
}
78 changes: 53 additions & 25 deletions openapi3filter/req_resp_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1120,33 +1120,47 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S
enc = encFn(name)
}
subEncFn := func(string) *openapi3.Encoding { return enc }
// If the property's schema has type "array" it is means that the form contains a few parts with the same name.
// Every such part has a type that is defined by an items schema in the property's schema.

var valueSchema *openapi3.SchemaRef
var exists bool
valueSchema, exists = schema.Value.Properties[name]
if !exists {
anyProperties := schema.Value.AdditionalPropertiesAllowed
if anyProperties != nil {
switch *anyProperties {
case true:
//additionalProperties: true
continue
default:
//additionalProperties: false
return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)}
if len(schema.Value.AllOf) > 0 {
var exists bool
for _, sr := range schema.Value.AllOf {
valueSchema, exists = sr.Value.Properties[name]
if exists {
break
}
}
if schema.Value.AdditionalProperties == nil {
if !exists {
return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)}
}
valueSchema, exists = schema.Value.AdditionalProperties.Value.Properties[name]
} else {
// If the property's schema has type "array" it is means that the form contains a few parts with the same name.
// Every such part has a type that is defined by an items schema in the property's schema.
var exists bool
valueSchema, exists = schema.Value.Properties[name]
if !exists {
return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)}
anyProperties := schema.Value.AdditionalPropertiesAllowed
if anyProperties != nil {
switch *anyProperties {
case true:
//additionalProperties: true
continue
default:
//additionalProperties: false
return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)}
}
}
if schema.Value.AdditionalProperties == nil {
return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)}
}
valueSchema, exists = schema.Value.AdditionalProperties.Value.Properties[name]
if !exists {
return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)}
}
}
if valueSchema.Value.Type == "array" {
valueSchema = valueSchema.Value.Items
}
}
if valueSchema.Value.Type == "array" {
valueSchema = valueSchema.Value.Items
}

var value interface{}
Expand All @@ -1160,14 +1174,28 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S
}

allTheProperties := make(map[string]*openapi3.SchemaRef)
for k, v := range schema.Value.Properties {
allTheProperties[k] = v
}
if schema.Value.AdditionalProperties != nil {
for k, v := range schema.Value.AdditionalProperties.Value.Properties {
if len(schema.Value.AllOf) > 0 {
for _, sr := range schema.Value.AllOf {
for k, v := range sr.Value.Properties {
allTheProperties[k] = v
}
if sr.Value.AdditionalProperties != nil {
for k, v := range sr.Value.AdditionalProperties.Value.Properties {
allTheProperties[k] = v
}
}
}
} else {
for k, v := range schema.Value.Properties {
allTheProperties[k] = v
}
if schema.Value.AdditionalProperties != nil {
for k, v := range schema.Value.AdditionalProperties.Value.Properties {
allTheProperties[k] = v
}
}
}

// Make an object value from form values.
obj := make(map[string]interface{})
for name, prop := range allTheProperties {
Expand Down