From 5d31793d6048a36ff5fc00c4310088503ca2d7d8 Mon Sep 17 00:00:00 2001 From: Matthew Dolan Date: Wed, 15 Nov 2017 19:10:28 -0800 Subject: [PATCH] Add proto marshaller for proto-over-http (#459) --- .gitignore | 1 + runtime/marshal_proto.go | 62 ++++++++++++++++++++++++ runtime/marshal_proto_test.go | 91 +++++++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+) create mode 100644 runtime/marshal_proto.go create mode 100644 runtime/marshal_proto_test.go diff --git a/.gitignore b/.gitignore index 88ddcdf4d4d..eb15433281d 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ _output/ +.idea diff --git a/runtime/marshal_proto.go b/runtime/marshal_proto.go new file mode 100644 index 00000000000..f65d1a2676b --- /dev/null +++ b/runtime/marshal_proto.go @@ -0,0 +1,62 @@ +package runtime + +import ( + "io" + + "errors" + "github.com/golang/protobuf/proto" + "io/ioutil" +) + +// ProtoMarshaller is a Marshaller which marshals/unmarshals into/from serialize proto bytes +type ProtoMarshaller struct{} + +// ContentType always returns "application/octet-stream". +func (*ProtoMarshaller) ContentType() string { + return "application/octet-stream" +} + +// Marshal marshals "value" into Proto +func (*ProtoMarshaller) Marshal(value interface{}) ([]byte, error) { + message, ok := value.(proto.Message) + if !ok { + return nil, errors.New("unable to marshal non proto field") + } + return proto.Marshal(message) +} + +// Unmarshal unmarshals proto "data" into "value" +func (*ProtoMarshaller) Unmarshal(data []byte, value interface{}) error { + message, ok := value.(proto.Message) + if !ok { + return errors.New("unable to unmarshal non proto field") + } + return proto.Unmarshal(data, message) +} + +// NewDecoder returns a Decoder which reads proto stream from "reader". +func (marshaller *ProtoMarshaller) NewDecoder(reader io.Reader) Decoder { + return DecoderFunc(func(value interface{}) error { + buffer, err := ioutil.ReadAll(reader) + if err != nil { + return err + } + return marshaller.Unmarshal(buffer, value) + }) +} + +// NewEncoder returns an Encoder which writes proto stream into "writer". +func (marshaller *ProtoMarshaller) NewEncoder(writer io.Writer) Encoder { + return EncoderFunc(func(value interface{}) error { + buffer, err := marshaller.Marshal(value) + if err != nil { + return err + } + _, err = writer.Write(buffer) + if err != nil { + return err + } + + return nil + }) +} diff --git a/runtime/marshal_proto_test.go b/runtime/marshal_proto_test.go new file mode 100644 index 00000000000..07dac47bd5c --- /dev/null +++ b/runtime/marshal_proto_test.go @@ -0,0 +1,91 @@ +package runtime_test + +import ( + "reflect" + "testing" + + "bytes" + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/grpc-ecosystem/grpc-gateway/examples/examplepb" + "github.com/grpc-ecosystem/grpc-gateway/runtime" +) + +var message = &examplepb.ABitOfEverything{ + SingleNested: &examplepb.ABitOfEverything_Nested{}, + RepeatedStringValue: nil, + MappedStringValue: nil, + MappedNestedValue: nil, + RepeatedEnumValue: nil, + TimestampValue: ×tamp.Timestamp{}, + Uuid: "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + Nested: []*examplepb.ABitOfEverything_Nested{ + { + Name: "foo", + Amount: 12345, + }, + }, + Uint64Value: 0xFFFFFFFFFFFFFFFF, + EnumValue: examplepb.NumericEnum_ONE, + OneofValue: &examplepb.ABitOfEverything_OneofString{ + OneofString: "bar", + }, + MapValue: map[string]examplepb.NumericEnum{ + "a": examplepb.NumericEnum_ONE, + "b": examplepb.NumericEnum_ZERO, + }, +} + +func TestProtoMarshalUnmarshal(t *testing.T) { + marshaller := runtime.ProtoMarshaller{} + + // Marshal + buffer, err := marshaller.Marshal(message) + if err != nil { + t.Fatalf("Marshalling returned error: %s", err.Error()) + } + + // Unmarshal + unmarshalled := &examplepb.ABitOfEverything{} + err = marshaller.Unmarshal(buffer, unmarshalled) + if err != nil { + t.Fatalf("Unmarshalling returned error: %s", err.Error()) + } + + if !reflect.DeepEqual(unmarshalled, message) { + t.Errorf( + "Unmarshalled didn't match original message: (original = %v) != (unmarshalled = %v)", + unmarshalled, + message, + ) + } +} + +func TestProtoEncoderDecodert(t *testing.T) { + marshaller := runtime.ProtoMarshaller{} + + var buf bytes.Buffer + + encoder := marshaller.NewEncoder(&buf) + decoder := marshaller.NewDecoder(&buf) + + // Encode + err := encoder.Encode(message) + if err != nil { + t.Fatalf("Encoding returned error: %s", err.Error()) + } + + // Decode + unencoded := &examplepb.ABitOfEverything{} + err = decoder.Decode(unencoded) + if err != nil { + t.Fatalf("Unmarshalling returned error: %s", err.Error()) + } + + if !reflect.DeepEqual(unencoded, message) { + t.Errorf( + "Unencoded didn't match original message: (original = %v) != (unencoded = %v)", + unencoded, + message, + ) + } +}