From 25457de41473a8cf0cea4fcdd418e97a8e0728b3 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Wed, 31 Jul 2024 03:15:19 +0900 Subject: [PATCH] Add nil checking for wrapperspb's types (#960) * add nil checking for wrapperspb types --- common/types/pb/type.go | 27 ++++++++++++++++++++++++++ common/types/pb/type_test.go | 36 +++++++++++++++++++++++++++++++++++ common/types/provider_test.go | 9 +++++++++ 3 files changed, 72 insertions(+) diff --git a/common/types/pb/type.go b/common/types/pb/type.go index 6cc95c27..bdd474c9 100644 --- a/common/types/pb/type.go +++ b/common/types/pb/type.go @@ -427,22 +427,49 @@ func unwrap(desc description, msg proto.Message) (any, bool, error) { return structpb.NullValue_NULL_VALUE, true, nil } case *wrapperspb.BoolValue: + if v == nil { + return nil, true, nil + } return v.GetValue(), true, nil case *wrapperspb.BytesValue: + if v == nil { + return nil, true, nil + } return v.GetValue(), true, nil case *wrapperspb.DoubleValue: + if v == nil { + return nil, true, nil + } return v.GetValue(), true, nil case *wrapperspb.FloatValue: + if v == nil { + return nil, true, nil + } return float64(v.GetValue()), true, nil case *wrapperspb.Int32Value: + if v == nil { + return nil, true, nil + } return int64(v.GetValue()), true, nil case *wrapperspb.Int64Value: + if v == nil { + return nil, true, nil + } return v.GetValue(), true, nil case *wrapperspb.StringValue: + if v == nil { + return nil, true, nil + } return v.GetValue(), true, nil case *wrapperspb.UInt32Value: + if v == nil { + return nil, true, nil + } return uint64(v.GetValue()), true, nil case *wrapperspb.UInt64Value: + if v == nil { + return nil, true, nil + } return v.GetValue(), true, nil } return msg, false, nil diff --git a/common/types/pb/type_test.go b/common/types/pb/type_test.go index 5aa36cb1..6c7b8c8f 100644 --- a/common/types/pb/type_test.go +++ b/common/types/pb/type_test.go @@ -354,38 +354,74 @@ func TestTypeDescriptionMaybeUnwrap(t *testing.T) { in: wrapperspb.Bool(true), out: true, }, + { + in: (*wrapperspb.BoolValue)(nil), + out: nil, + }, { in: wrapperspb.Bytes([]byte("hello")), out: []byte("hello"), }, + { + in: (*wrapperspb.BytesValue)(nil), + out: nil, + }, { in: wrapperspb.Double(-4.2), out: -4.2, }, + { + in: (*wrapperspb.DoubleValue)(nil), + out: nil, + }, { in: wrapperspb.Float(4.5), out: 4.5, }, + { + in: (*wrapperspb.FloatValue)(nil), + out: nil, + }, { in: wrapperspb.Int32(123), out: int64(123), }, + { + in: (*wrapperspb.Int32Value)(nil), + out: nil, + }, { in: wrapperspb.Int64(456), out: int64(456), }, + { + in: (*wrapperspb.Int64Value)(nil), + out: nil, + }, { in: wrapperspb.String("goodbye"), out: "goodbye", }, + { + in: (*wrapperspb.StringValue)(nil), + out: nil, + }, { in: wrapperspb.UInt32(1234), out: uint64(1234), }, + { + in: (*wrapperspb.UInt32Value)(nil), + out: nil, + }, { in: wrapperspb.UInt64(5678), out: uint64(5678), }, + { + in: (*wrapperspb.UInt64Value)(nil), + out: nil, + }, { in: tpb.New(time.Unix(12345, 0).UTC()), out: time.Unix(12345, 0).UTC(), diff --git a/common/types/provider_test.go b/common/types/provider_test.go index a2b2026a..2576a574 100644 --- a/common/types/provider_test.go +++ b/common/types/provider_test.go @@ -695,22 +695,31 @@ func TestNativeToValue_Wrappers(t *testing.T) { // Wrapper conversion test. expectNativeToValue(t, wrapperspb.Bool(true), True) expectNativeToValue(t, &wrapperspb.BoolValue{}, False) + expectNativeToValue(t, (*wrapperspb.BoolValue)(nil), NullValue) expectNativeToValue(t, &wrapperspb.BytesValue{}, Bytes{}) expectNativeToValue(t, wrapperspb.Bytes([]byte("hi")), Bytes("hi")) + expectNativeToValue(t, (*wrapperspb.BytesValue)(nil), NullValue) expectNativeToValue(t, &wrapperspb.DoubleValue{}, Double(0.0)) expectNativeToValue(t, wrapperspb.Double(6.4), Double(6.4)) + expectNativeToValue(t, (*wrapperspb.DoubleValue)(nil), NullValue) expectNativeToValue(t, &wrapperspb.FloatValue{}, Double(0.0)) expectNativeToValue(t, wrapperspb.Float(3.0), Double(3.0)) + expectNativeToValue(t, (*wrapperspb.FloatValue)(nil), NullValue) expectNativeToValue(t, &wrapperspb.Int32Value{}, IntZero) expectNativeToValue(t, wrapperspb.Int32(-32), Int(-32)) + expectNativeToValue(t, (*wrapperspb.Int32Value)(nil), NullValue) expectNativeToValue(t, &wrapperspb.Int64Value{}, IntZero) expectNativeToValue(t, wrapperspb.Int64(-64), Int(-64)) + expectNativeToValue(t, (*wrapperspb.Int64Value)(nil), NullValue) expectNativeToValue(t, &wrapperspb.StringValue{}, String("")) expectNativeToValue(t, wrapperspb.String("hello"), String("hello")) + expectNativeToValue(t, (*wrapperspb.StringValue)(nil), NullValue) expectNativeToValue(t, &wrapperspb.UInt32Value{}, Uint(0)) expectNativeToValue(t, wrapperspb.UInt32(32), Uint(32)) + expectNativeToValue(t, (*wrapperspb.UInt32Value)(nil), NullValue) expectNativeToValue(t, &wrapperspb.UInt64Value{}, Uint(0)) expectNativeToValue(t, wrapperspb.UInt64(64), Uint(64)) + expectNativeToValue(t, (*wrapperspb.UInt64Value)(nil), NullValue) } func TestNativeToValue_Primitive(t *testing.T) {