From ffcc8999679373681eb9d419f483714e00778168 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Mon, 9 May 2016 14:07:28 -0400 Subject: [PATCH] custom marshaler: handle `Accept` headers correctly Also greatly simplifies existing logic to avoid special casing of defaults. --- runtime/marshaler_registry.go | 132 ++++++++++------------------- runtime/marshaler_registry_test.go | 21 ++--- runtime/mux.go | 2 +- 3 files changed, 50 insertions(+), 105 deletions(-) diff --git a/runtime/marshaler_registry.go b/runtime/marshaler_registry.go index bc98a6ef1a5..501d37eb3b2 100644 --- a/runtime/marshaler_registry.go +++ b/runtime/marshaler_registry.go @@ -5,12 +5,15 @@ import ( "net/http" ) -const mimeWildcard = "*" +// MIMEWildCard is the fallback MIME type used for requests which do not match +// a registered MIME type. +const MIMEWildcard = "*" var ( - defaultMarshaler = &JSONPb{OrigName: true} - + acceptHeader = http.CanonicalHeaderKey("Accept") contentTypeHeader = http.CanonicalHeaderKey("Content-Type") + + defaultMarshaler = &JSONPb{OrigName: true} ) // MarshalerForRequest returns the inbound/outbound marshalers for this request. @@ -20,117 +23,68 @@ var ( // exactly match in the registry. // Otherwise, it follows the above logic for "*"/InboundMarshaler/OutboundMarshaler. func MarshalerForRequest(mux *ServeMux, r *http.Request) (inbound Marshaler, outbound Marshaler) { - headerVals := append(append([]string(nil), r.Header[contentTypeHeader]...), "*") - - for _, val := range headerVals { - m := mux.marshalers.lookup(val) - if m != nil { - if inbound == nil { - inbound = m.inbound - } - if outbound == nil { - outbound = m.outbound - } + for _, acceptVal := range r.Header[acceptHeader] { + if m, ok := mux.marshalers.mimeMap[acceptVal]; ok { + outbound = m + break } - if inbound != nil && outbound != nil { - // Got them both, return - return inbound, outbound + } + + for _, contentTypeVal := range r.Header[contentTypeHeader] { + if m, ok := mux.marshalers.mimeMap[contentTypeVal]; ok { + inbound = m + break } } + if inbound == nil { - inbound = defaultMarshaler + inbound = mux.marshalers.mimeMap[MIMEWildcard] } if outbound == nil { - outbound = defaultMarshaler + outbound = inbound } - return inbound, outbound + return inbound, outbound } -// marshalerRegistry keeps a mapping from MIME types to mimeMarshalers. -type marshalerRegistry map[string]*mimeMarshaler - -type mimeMarshaler struct { - inbound Marshaler - outbound Marshaler +// marshalerRegistry is a mapping from MIME types to Marshalers. +type marshalerRegistry struct { + mimeMap map[string]Marshaler } -// addMarshaler adds an inbound and outbund marshaler for a case-sensitive MIME type string ("*" to match any MIME type). -// Inbound is the marshaler that is used when marshaling inbound requests from the client. -// Outbound is the marshaler that is used when marshaling outbound responses to the client. -func (r *marshalerRegistry) add(mime string, inbound, outbound Marshaler) error { - if mime == "" { +// add adds a marshaler for a case-sensitive MIME type string ("*" to match any +// MIME type). +func (m marshalerRegistry) add(mime string, marshaler Marshaler) error { + if len(mime) == 0 { return errors.New("empty MIME type") } - (*r)[mime] = &mimeMarshaler{ - inbound: inbound, - outbound: outbound, - } - return nil -} -// addInboundMarshaler adds an inbound marshaler for a case-sensitive MIME type string ("*" to match any MIME type). -// Inbound is the marshaler that is used when marshaling inbound requests from the client. -func (r *marshalerRegistry) addInbound(mime string, inbound Marshaler) error { - if mime == "" { - return errors.New("empty MIME type") - } - if entry := (*r)[mime]; entry != nil { - entry.inbound = inbound - return nil - } - (*r)[mime] = &mimeMarshaler{inbound: inbound} - return nil -} + m.mimeMap[mime] = marshaler -// addOutBound adds an outbund marshaler for a case-sensitive MIME type string ("*" to match any MIME type). -// Outbound is the marshaler that is used when marshaling outbound responses to the client. -func (r *marshalerRegistry) addOutbound(mime string, outbound Marshaler) error { - mime = http.CanonicalHeaderKey(mime) - if mime == "" { - return errors.New("empty MIME type") - } - if entry := (*r)[mime]; entry != nil { - entry.outbound = outbound - return nil - } - (*r)[mime] = &mimeMarshaler{outbound: outbound} return nil - } -func (r *marshalerRegistry) lookup(mime string) *mimeMarshaler { - if r == nil { - return nil +// makeMarshalerMIMERegistry returns a new registry of marshalers. +// It allows for a mapping of case-sensitive Content-Type MIME type string to runtime.Marshaler interfaces. +// +// For example, you could allow the client to specify the use of the runtime.JSONPb marshaler +// with a "applicaton/jsonpb" Content-Type and the use of the runtime.JSONBuiltin marshaler +// with a "application/json" Content-Type. +// "*" can be used to match any Content-Type. +// This can be attached to a ServerMux with the marshaler option. +func makeMarshalerMIMERegistry() marshalerRegistry { + return marshalerRegistry{ + mimeMap: map[string]Marshaler{ + MIMEWildcard: defaultMarshaler, + }, } - return (*r)[mime] } // WithMarshalerOption returns a ServeMuxOption which associates inbound and outbound // Marshalers to a MIME type in mux. -func WithMarshalerOption(mime string, in, out Marshaler) ServeMuxOption { - return func(mux *ServeMux) { - if err := mux.marshalers.add(mime, in, out); err != nil { - panic(err) - } - } -} - -// WithInboundMarshalerOption returns a ServeMuxOption which associates an inbound -// Marshaler to a MIME type in mux. -func WithInboundMarshalerOption(mime string, in Marshaler) ServeMuxOption { - return func(mux *ServeMux) { - if err := mux.marshalers.addInbound(mime, in); err != nil { - panic(err) - } - } -} - -// WithOutboundMarshalerOption returns a ServeMuxOption which associates an outbound -// Marshaler to a MIME type in mux. -func WithOutboundMarshalerOption(mime string, out Marshaler) ServeMuxOption { +func WithMarshalerOption(mime string, marshaler Marshaler) ServeMuxOption { return func(mux *ServeMux) { - if err := mux.marshalers.addOutbound(mime, out); err != nil { + if err := mux.marshalers.add(mime, marshaler); err != nil { panic(err) } } diff --git a/runtime/marshaler_registry_test.go b/runtime/marshaler_registry_test.go index 38a47b462cd..206f6ed9a77 100644 --- a/runtime/marshaler_registry_test.go +++ b/runtime/marshaler_registry_test.go @@ -14,7 +14,8 @@ func TestMarshalerForRequest(t *testing.T) { if err != nil { t.Fatalf(`http.NewRequest("GET", "http://example.com", nil) failed with %v; want success`, err) } - r.Header.Set("Content-Type", "application/x-example") + r.Header.Set("Accept", "application/x-out") + r.Header.Set("Content-Type", "application/x-in") mux := runtime.NewServeMux() @@ -26,7 +27,7 @@ func TestMarshalerForRequest(t *testing.T) { t.Errorf("out = %#v; want a runtime.JSONPb", in) } - var marshalers [6]dummyMarshaler + var marshalers [3]dummyMarshaler specs := []struct { opt runtime.ServeMuxOption @@ -34,30 +35,20 @@ func TestMarshalerForRequest(t *testing.T) { wantOut runtime.Marshaler }{ { - opt: runtime.WithMarshalerOption("*", &marshalers[0], &marshalers[0]), + opt: runtime.WithMarshalerOption(runtime.MIMEWildcard, &marshalers[0]), wantIn: &marshalers[0], wantOut: &marshalers[0], }, { - opt: runtime.WithInboundMarshalerOption("*", &marshalers[1]), + opt: runtime.WithMarshalerOption("application/x-in", &marshalers[1]), wantIn: &marshalers[1], wantOut: &marshalers[0], }, { - opt: runtime.WithOutboundMarshalerOption("application/x-example", &marshalers[2]), + opt: runtime.WithMarshalerOption("application/x-out", &marshalers[2]), wantIn: &marshalers[1], wantOut: &marshalers[2], }, - { - opt: runtime.WithInboundMarshalerOption("application/x-example", &marshalers[3]), - wantIn: &marshalers[3], - wantOut: &marshalers[2], - }, - { - opt: runtime.WithMarshalerOption("application/x-example", &marshalers[4], &marshalers[5]), - wantIn: &marshalers[4], - wantOut: &marshalers[5], - }, } for i, spec := range specs { var opts []runtime.ServeMuxOption diff --git a/runtime/mux.go b/runtime/mux.go index 27bc8da6b51..2e6c5621302 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -41,7 +41,7 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux { serveMux := &ServeMux{ handlers: make(map[string][]handler), forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0), - marshalers: make(marshalerRegistry), + marshalers: makeMarshalerMIMERegistry(), } for _, opt := range opts {