From e23784af957ef8ee1530f173164d26fc01351b50 Mon Sep 17 00:00:00 2001 From: micronull Date: Tue, 22 Nov 2022 11:22:16 +0500 Subject: [PATCH 1/3] fix: openapi3.SchemaError message customize (#678) --- openapi3/schema.go | 245 ++++++++++++++---------- openapi3/schema_test.go | 28 +++ openapi3/schema_validation_settings.go | 7 + openapi3filter/options.go | 5 + openapi3filter/validate_request.go | 6 + openapi3filter/validate_request_test.go | 51 +++++ openapi3filter/validate_response.go | 3 + 7 files changed, 239 insertions(+), 106 deletions(-) diff --git a/openapi3/schema.go b/openapi3/schema.go index e5cded877..8e7b8d1ed 100644 --- a/openapi3/schema.go +++ b/openapi3/schema.go @@ -861,10 +861,11 @@ func (schema *Schema) visitJSON(settings *schemaValidationSettings, value interf } } return &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "type", - Reason: fmt.Sprintf("unhandled value of type %T", value), + Value: value, + Schema: schema, + SchemaField: "type", + Reason: fmt.Sprintf("unhandled value of type %T", value), + customizeMessageError: settings.customizeMessageError, } } @@ -879,10 +880,11 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return errSchema } return &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "enum", - Reason: "value is not one of the allowed values", + Value: value, + Schema: schema, + SchemaField: "enum", + Reason: "value is not one of the allowed values", + customizeMessageError: settings.customizeMessageError, } } @@ -896,9 +898,10 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return errSchema } return &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "not", + Value: value, + Schema: schema, + SchemaField: "not", + customizeMessageError: settings.customizeMessageError, } } } @@ -961,9 +964,10 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return errSchema } e := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "oneOf", + Value: value, + Schema: schema, + SchemaField: "oneOf", + customizeMessageError: settings.customizeMessageError, } if ok > 1 { e.Origin = ErrOneOfConflict @@ -1005,9 +1009,10 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return errSchema } return &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "anyOf", + Value: value, + Schema: schema, + SchemaField: "anyOf", + customizeMessageError: settings.customizeMessageError, } } @@ -1024,10 +1029,11 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return errSchema } return &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "allOf", - Origin: err, + Value: value, + Schema: schema, + SchemaField: "allOf", + Origin: err, + customizeMessageError: settings.customizeMessageError, } } } @@ -1042,10 +1048,11 @@ func (schema *Schema) visitJSONNull(settings *schemaValidationSettings) (err err return errSchema } return &SchemaError{ - Value: nil, - Schema: schema, - SchemaField: "nullable", - Reason: "Value is not nullable", + Value: nil, + Schema: schema, + SchemaField: "nullable", + Reason: "Value is not nullable", + customizeMessageError: settings.customizeMessageError, } } @@ -1075,10 +1082,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "type", - Reason: "Value must be an integer", + Value: value, + Schema: schema, + SchemaField: "type", + Reason: "Value must be an integer", + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1110,10 +1118,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "format", - Reason: fmt.Sprintf("number must be an %s", schema.Format), + Value: value, + Schema: schema, + SchemaField: "format", + Reason: fmt.Sprintf("number must be an %s", schema.Format), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1128,10 +1137,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "exclusiveMinimum", - Reason: fmt.Sprintf("number must be more than %g", *schema.Min), + Value: value, + Schema: schema, + SchemaField: "exclusiveMinimum", + Reason: fmt.Sprintf("number must be more than %g", *schema.Min), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1145,10 +1155,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "exclusiveMaximum", - Reason: fmt.Sprintf("number must be less than %g", *schema.Max), + Value: value, + Schema: schema, + SchemaField: "exclusiveMaximum", + Reason: fmt.Sprintf("number must be less than %g", *schema.Max), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1162,10 +1173,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "minimum", - Reason: fmt.Sprintf("number must be at least %g", *v), + Value: value, + Schema: schema, + SchemaField: "minimum", + Reason: fmt.Sprintf("number must be at least %g", *v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1179,10 +1191,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "maximum", - Reason: fmt.Sprintf("number must be at most %g", *v), + Value: value, + Schema: schema, + SchemaField: "maximum", + Reason: fmt.Sprintf("number must be at most %g", *v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1199,9 +1212,10 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "multipleOf", + Value: value, + Schema: schema, + SchemaField: "multipleOf", + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1247,10 +1261,11 @@ func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "minLength", - Reason: fmt.Sprintf("minimum string length is %d", minLength), + Value: value, + Schema: schema, + SchemaField: "minLength", + Reason: fmt.Sprintf("minimum string length is %d", minLength), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1262,10 +1277,11 @@ func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "maxLength", - Reason: fmt.Sprintf("maximum string length is %d", *maxLength), + Value: value, + Schema: schema, + SchemaField: "maxLength", + Reason: fmt.Sprintf("maximum string length is %d", *maxLength), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1286,10 +1302,11 @@ func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value } if cp := schema.compiledPattern; cp != nil && !cp.MatchString(value) { err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "pattern", - Reason: fmt.Sprintf(`string doesn't match the regular expression "%s"`, schema.Pattern), + Value: value, + Schema: schema, + SchemaField: "pattern", + Reason: fmt.Sprintf(`string doesn't match the regular expression "%s"`, schema.Pattern), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1317,10 +1334,11 @@ func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value } if formatErr != "" { err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "format", - Reason: formatErr, + Value: value, + Schema: schema, + SchemaField: "format", + Reason: formatErr, + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1356,10 +1374,11 @@ func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value [ return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "minItems", - Reason: fmt.Sprintf("minimum number of items is %d", v), + Value: value, + Schema: schema, + SchemaField: "minItems", + Reason: fmt.Sprintf("minimum number of items is %d", v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1373,10 +1392,11 @@ func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value [ return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "maxItems", - Reason: fmt.Sprintf("maximum number of items is %d", *v), + Value: value, + Schema: schema, + SchemaField: "maxItems", + Reason: fmt.Sprintf("maximum number of items is %d", *v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1393,10 +1413,11 @@ func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value [ return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "uniqueItems", - Reason: "duplicate items found", + Value: value, + Schema: schema, + SchemaField: "uniqueItems", + Reason: "duplicate items found", + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1484,10 +1505,11 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "minProperties", - Reason: fmt.Sprintf("there must be at least %d properties", v), + Value: value, + Schema: schema, + SchemaField: "minProperties", + Reason: fmt.Sprintf("there must be at least %d properties", v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1501,10 +1523,11 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "maxProperties", - Reason: fmt.Sprintf("there must be at most %d properties", *v), + Value: value, + Schema: schema, + SchemaField: "maxProperties", + Reason: fmt.Sprintf("there must be at most %d properties", *v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1572,10 +1595,11 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "properties", - Reason: fmt.Sprintf("property %q is unsupported", k), + Value: value, + Schema: schema, + SchemaField: "properties", + Reason: fmt.Sprintf("property %q is unsupported", k), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1596,10 +1620,11 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value return errSchema } err := markSchemaErrorKey(&SchemaError{ - Value: value, - Schema: schema, - SchemaField: "required", - Reason: fmt.Sprintf("property %q is missing", k), + Value: value, + Schema: schema, + SchemaField: "required", + Reason: fmt.Sprintf("property %q is missing", k), + customizeMessageError: settings.customizeMessageError, }, k) if !settings.multiError { return err @@ -1620,10 +1645,11 @@ func (schema *Schema) expectedType(settings *schemaValidationSettings, typ strin return errSchema } return &SchemaError{ - Value: typ, - Schema: schema, - SchemaField: "type", - Reason: "Field must be set to " + schema.Type + " or not be present", + Value: typ, + Schema: schema, + SchemaField: "type", + Reason: "Field must be set to " + schema.Type + " or not be present", + customizeMessageError: settings.customizeMessageError, } } @@ -1639,12 +1665,13 @@ func (schema *Schema) compilePattern() (err error) { } type SchemaError struct { - Value interface{} - reversePath []string - Schema *Schema - SchemaField string - Reason string - Origin error + Value interface{} + reversePath []string + Schema *Schema + SchemaField string + Reason string + Origin error + customizeMessageError func(err *SchemaError) string } var _ interface{ Unwrap() error } = SchemaError{} @@ -1687,6 +1714,12 @@ func (err *SchemaError) JSONPointer() []string { } func (err *SchemaError) Error() string { + if err.customizeMessageError != nil { + if msg := err.customizeMessageError(err); msg != "" { + return msg + } + } + if err.Origin != nil { return err.Origin.Error() } diff --git a/openapi3/schema_test.go b/openapi3/schema_test.go index abec30477..39e4bde52 100644 --- a/openapi3/schema_test.go +++ b/openapi3/schema_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/json" + "fmt" "math" "reflect" "strings" @@ -1089,6 +1090,33 @@ func TestSchemaErrors(t *testing.T) { } } +func TestSchemaError_CustomMessage(t *testing.T) { + t.Parallel() + + loader := NewLoader() + spc := ` +components: + schemas: + Something: + type: object + properties: + field: + title: Some field + type: string +`[1:] + + doc, err := loader.LoadFromData([]byte(spc)) + require.NoError(t, err) + + opt := SetSchemaErrorCustomMessage(func(err *SchemaError) string { + return fmt.Sprintf("foobar: %s", err.Schema.Title) + }) + + err = doc.Components.Schemas["Something"].Value.Properties["field"].Value.VisitJSON(123, opt) + + require.EqualError(t, err, "foobar: Some field") +} + func testSchemaError(t *testing.T, example schemaErrorExample) func(*testing.T) { return func(t *testing.T) { msg := example.Error.Error() diff --git a/openapi3/schema_validation_settings.go b/openapi3/schema_validation_settings.go index 854ae8480..323a2e3a0 100644 --- a/openapi3/schema_validation_settings.go +++ b/openapi3/schema_validation_settings.go @@ -16,6 +16,8 @@ type schemaValidationSettings struct { onceSettingDefaults sync.Once defaultsSet func() + + customizeMessageError func(err *SchemaError) string } // FailFast returns schema validation errors quicker. @@ -50,6 +52,11 @@ func DefaultsSet(f func()) SchemaValidationOption { return func(s *schemaValidationSettings) { s.defaultsSet = f } } +// SetSchemaErrorCustomMessage allows to override the schema error message. +func SetSchemaErrorCustomMessage(f func(err *SchemaError) string) SchemaValidationOption { + return func(s *schemaValidationSettings) { s.customizeMessageError = f } +} + func newSchemaValidationSettings(opts ...SchemaValidationOption) *schemaValidationSettings { settings := &schemaValidationSettings{} for _, opt := range opts { diff --git a/openapi3filter/options.go b/openapi3filter/options.go index 1622339e2..61d5edea3 100644 --- a/openapi3filter/options.go +++ b/openapi3filter/options.go @@ -1,5 +1,7 @@ package openapi3filter +import "github.com/getkin/kin-openapi/openapi3" + // DefaultOptions do not set an AuthenticationFunc. // A spec with security schemes defined will not pass validation // unless an AuthenticationFunc is defined. @@ -21,4 +23,7 @@ type Options struct { // See NoopAuthenticationFunc AuthenticationFunc AuthenticationFunc + + // Sets a function to override the schema error message. + CustomSchemaErrorFunc func(err *openapi3.SchemaError) string } diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index beb47aaad..0ed269076 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -178,6 +178,9 @@ func ValidateParameter(ctx context.Context, input *RequestValidationInput, param opts = make([]openapi3.SchemaValidationOption, 0, 1) opts = append(opts, openapi3.MultiErrors()) } + if options.CustomSchemaErrorFunc != nil { + opts = append(opts, openapi3.SetSchemaErrorCustomMessage(options.CustomSchemaErrorFunc)) + } if err = schema.VisitJSON(value, opts...); err != nil { return &RequestError{Input: input, Parameter: parameter, Err: err} } @@ -262,6 +265,9 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } + if options.CustomSchemaErrorFunc != nil { + opts = append(opts, openapi3.SetSchemaErrorCustomMessage(options.CustomSchemaErrorFunc)) + } // Validate JSON with the schema if err := contentType.Schema.Value.VisitJSON(value, opts...); err != nil { diff --git a/openapi3filter/validate_request_test.go b/openapi3filter/validate_request_test.go index 450ee5988..4994f8514 100644 --- a/openapi3filter/validate_request_test.go +++ b/openapi3filter/validate_request_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -215,3 +216,53 @@ components: }) } } + +func TestValidateRequest_CustomSchemaErrorMessage(t *testing.T) { + t.Parallel() + + const spec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: 0.0.1 +paths: + /some: + post: + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + field: + title: Some field + type: string + responses: + '200': + description: Created +` + router := setupTestRouter(t, spec) + + req, err := http.NewRequest(http.MethodPost, "/some", strings.NewReader(`{"field":123}`)) + require.NoError(t, err) + + req.Header.Add("Content-Type", "application/json") + + route, pathParams, err := router.FindRoute(req) + require.NoError(t, err) + + validationInput := &RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + Options: &Options{ + CustomSchemaErrorFunc: func(err *openapi3.SchemaError) string { + return fmt.Sprintf("foobar: %s", err.Schema.Title) + }, + }, + } + err = ValidateRequest(context.Background(), validationInput) + + require.ErrorContains(t, err, "foobar: Some field") +} diff --git a/openapi3filter/validate_response.go b/openapi3filter/validate_response.go index e90b5d60e..08214edec 100644 --- a/openapi3filter/validate_response.go +++ b/openapi3filter/validate_response.go @@ -66,6 +66,9 @@ func ValidateResponse(ctx context.Context, input *ResponseValidationInput) error if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } + if options.CustomSchemaErrorFunc != nil { + opts = append(opts, openapi3.SetSchemaErrorCustomMessage(options.CustomSchemaErrorFunc)) + } headers := make([]string, 0, len(response.Headers)) for k := range response.Headers { From ca0759a4f319a3fff007661f889b538a02d54fbe Mon Sep 17 00:00:00 2001 From: micronull Date: Tue, 22 Nov 2022 16:42:10 +0500 Subject: [PATCH 2/3] set custom schema error func via WithCustomSchemaErrorFunc --- openapi3filter/options.go | 12 ++++- openapi3filter/options_test.go | 83 +++++++++++++++++++++++++++++ openapi3filter/validate_request.go | 8 +-- openapi3filter/validate_response.go | 4 +- 4 files changed, 99 insertions(+), 8 deletions(-) create mode 100644 openapi3filter/options_test.go diff --git a/openapi3filter/options.go b/openapi3filter/options.go index 61d5edea3..aacf04c68 100644 --- a/openapi3filter/options.go +++ b/openapi3filter/options.go @@ -24,6 +24,14 @@ type Options struct { // See NoopAuthenticationFunc AuthenticationFunc AuthenticationFunc - // Sets a function to override the schema error message. - CustomSchemaErrorFunc func(err *openapi3.SchemaError) string + customSchemaErrorFunc CustomSchemaErrorFunc +} + +// CustomSchemaErrorFunc allows for custom the schema error message. +type CustomSchemaErrorFunc func(err *openapi3.SchemaError) string + +// WithCustomSchemaErrorFunc sets a function to override the schema error message. +// If the passed function returns an empty string, it returns to the previous Error() implementation. +func (o *Options) WithCustomSchemaErrorFunc(f CustomSchemaErrorFunc) { + o.customSchemaErrorFunc = f } diff --git a/openapi3filter/options_test.go b/openapi3filter/options_test.go new file mode 100644 index 000000000..12737114d --- /dev/null +++ b/openapi3filter/options_test.go @@ -0,0 +1,83 @@ +package openapi3filter_test + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" +) + +func ExampleOptions_WithCustomSchemaErrorFunc() { + const spec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: 0.0.1 +paths: + /some: + post: + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + field: + title: Some field + type: integer + responses: + '200': + description: Created +` + + loader := openapi3.NewLoader() + doc, err := loader.LoadFromData([]byte(spec)) + if err != nil { + panic(err) + } + + err = doc.Validate(loader.Context) + if err != nil { + panic(err) + } + + router, err := gorillamux.NewRouter(doc) + if err != nil { + panic(err) + } + + opts := &openapi3filter.Options{} + + opts.WithCustomSchemaErrorFunc(func(err *openapi3.SchemaError) string { + return fmt.Sprintf(`field "%s" must be an integer`, err.Schema.Title) + }) + + req, err := http.NewRequest(http.MethodPost, "/some", strings.NewReader(`{"field":"not integer"}`)) + if err != nil { + panic(err) + } + + req.Header.Add("Content-Type", "application/json") + + route, pathParams, err := router.FindRoute(req) + if err != nil { + panic(err) + } + + validationInput := &openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + Options: opts, + } + err = openapi3filter.ValidateRequest(context.Background(), validationInput) + + fmt.Println(err.Error()) + + // Output: request body has an error: doesn't match the schema: field "Some field" must be an integer +} diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index 0ed269076..6c34cd6f3 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -178,8 +178,8 @@ func ValidateParameter(ctx context.Context, input *RequestValidationInput, param opts = make([]openapi3.SchemaValidationOption, 0, 1) opts = append(opts, openapi3.MultiErrors()) } - if options.CustomSchemaErrorFunc != nil { - opts = append(opts, openapi3.SetSchemaErrorCustomMessage(options.CustomSchemaErrorFunc)) + if options.customSchemaErrorFunc != nil { + opts = append(opts, openapi3.SetSchemaErrorMessageCustomizer(options.customSchemaErrorFunc)) } if err = schema.VisitJSON(value, opts...); err != nil { return &RequestError{Input: input, Parameter: parameter, Err: err} @@ -265,8 +265,8 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } - if options.CustomSchemaErrorFunc != nil { - opts = append(opts, openapi3.SetSchemaErrorCustomMessage(options.CustomSchemaErrorFunc)) + if options.customSchemaErrorFunc != nil { + opts = append(opts, openapi3.SetSchemaErrorMessageCustomizer(options.customSchemaErrorFunc)) } // Validate JSON with the schema diff --git a/openapi3filter/validate_response.go b/openapi3filter/validate_response.go index 08214edec..abcbb4e9d 100644 --- a/openapi3filter/validate_response.go +++ b/openapi3filter/validate_response.go @@ -66,8 +66,8 @@ func ValidateResponse(ctx context.Context, input *ResponseValidationInput) error if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } - if options.CustomSchemaErrorFunc != nil { - opts = append(opts, openapi3.SetSchemaErrorCustomMessage(options.CustomSchemaErrorFunc)) + if options.customSchemaErrorFunc != nil { + opts = append(opts, openapi3.SetSchemaErrorMessageCustomizer(options.customSchemaErrorFunc)) } headers := make([]string, 0, len(response.Headers)) From c8b5b4579f78f3ead6831b45a392f4f02a5baef5 Mon Sep 17 00:00:00 2001 From: micronull Date: Tue, 22 Nov 2022 16:42:59 +0500 Subject: [PATCH 3/3] tests changed to examples and add docs --- openapi3/schema_test.go | 28 ----------- openapi3/schema_validation_settings.go | 5 +- openapi3/schema_validation_settings_test.go | 36 +++++++++++++++ openapi3filter/validate_request_test.go | 51 --------------------- 4 files changed, 39 insertions(+), 81 deletions(-) create mode 100644 openapi3/schema_validation_settings_test.go diff --git a/openapi3/schema_test.go b/openapi3/schema_test.go index 39e4bde52..abec30477 100644 --- a/openapi3/schema_test.go +++ b/openapi3/schema_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "encoding/json" - "fmt" "math" "reflect" "strings" @@ -1090,33 +1089,6 @@ func TestSchemaErrors(t *testing.T) { } } -func TestSchemaError_CustomMessage(t *testing.T) { - t.Parallel() - - loader := NewLoader() - spc := ` -components: - schemas: - Something: - type: object - properties: - field: - title: Some field - type: string -`[1:] - - doc, err := loader.LoadFromData([]byte(spc)) - require.NoError(t, err) - - opt := SetSchemaErrorCustomMessage(func(err *SchemaError) string { - return fmt.Sprintf("foobar: %s", err.Schema.Title) - }) - - err = doc.Components.Schemas["Something"].Value.Properties["field"].Value.VisitJSON(123, opt) - - require.EqualError(t, err, "foobar: Some field") -} - func testSchemaError(t *testing.T, example schemaErrorExample) func(*testing.T) { return func(t *testing.T) { msg := example.Error.Error() diff --git a/openapi3/schema_validation_settings.go b/openapi3/schema_validation_settings.go index 323a2e3a0..5a28c8d8d 100644 --- a/openapi3/schema_validation_settings.go +++ b/openapi3/schema_validation_settings.go @@ -52,8 +52,9 @@ func DefaultsSet(f func()) SchemaValidationOption { return func(s *schemaValidationSettings) { s.defaultsSet = f } } -// SetSchemaErrorCustomMessage allows to override the schema error message. -func SetSchemaErrorCustomMessage(f func(err *SchemaError) string) SchemaValidationOption { +// SetSchemaErrorMessageCustomizer allows to override the schema error message. +// If the passed function returns an empty string, it returns to the previous Error() implementation. +func SetSchemaErrorMessageCustomizer(f func(err *SchemaError) string) SchemaValidationOption { return func(s *schemaValidationSettings) { s.customizeMessageError = f } } diff --git a/openapi3/schema_validation_settings_test.go b/openapi3/schema_validation_settings_test.go new file mode 100644 index 000000000..db52d3bdf --- /dev/null +++ b/openapi3/schema_validation_settings_test.go @@ -0,0 +1,36 @@ +package openapi3_test + +import ( + "fmt" + + "github.com/getkin/kin-openapi/openapi3" +) + +func ExampleSetSchemaErrorMessageCustomizer() { + loader := openapi3.NewLoader() + spc := ` +components: + schemas: + Something: + type: object + properties: + field: + title: Some field + type: string +`[1:] + + doc, err := loader.LoadFromData([]byte(spc)) + if err != nil { + panic(err) + } + + opt := openapi3.SetSchemaErrorMessageCustomizer(func(err *openapi3.SchemaError) string { + return fmt.Sprintf(`field "%s" should be string`, err.Schema.Title) + }) + + err = doc.Components.Schemas["Something"].Value.Properties["field"].Value.VisitJSON(123, opt) + + fmt.Println(err.Error()) + + // Output: field "Some field" should be string +} diff --git a/openapi3filter/validate_request_test.go b/openapi3filter/validate_request_test.go index 4994f8514..450ee5988 100644 --- a/openapi3filter/validate_request_test.go +++ b/openapi3filter/validate_request_test.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "net/http" - "strings" "testing" "github.com/stretchr/testify/assert" @@ -216,53 +215,3 @@ components: }) } } - -func TestValidateRequest_CustomSchemaErrorMessage(t *testing.T) { - t.Parallel() - - const spec = ` -openapi: 3.0.0 -info: - title: 'Validator' - version: 0.0.1 -paths: - /some: - post: - requestBody: - required: true - content: - application/json: - schema: - type: object - properties: - field: - title: Some field - type: string - responses: - '200': - description: Created -` - router := setupTestRouter(t, spec) - - req, err := http.NewRequest(http.MethodPost, "/some", strings.NewReader(`{"field":123}`)) - require.NoError(t, err) - - req.Header.Add("Content-Type", "application/json") - - route, pathParams, err := router.FindRoute(req) - require.NoError(t, err) - - validationInput := &RequestValidationInput{ - Request: req, - PathParams: pathParams, - Route: route, - Options: &Options{ - CustomSchemaErrorFunc: func(err *openapi3.SchemaError) string { - return fmt.Sprintf("foobar: %s", err.Schema.Title) - }, - }, - } - err = ValidateRequest(context.Background(), validationInput) - - require.ErrorContains(t, err, "foobar: Some field") -}