Skip to content

Commit

Permalink
Add nil checking for wrapperspb's types (#960)
Browse files Browse the repository at this point in the history
* add nil checking for wrapperspb types
  • Loading branch information
goccy authored Jul 30, 2024
1 parent badfce0 commit 25457de
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 0 deletions.
27 changes: 27 additions & 0 deletions common/types/pb/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,22 +427,49 @@ func unwrap(desc description, msg proto.Message) (any, bool, error) {
return structpb.NullValue_NULL_VALUE, true, nil
}
case *wrapperspb.BoolValue:
if v == nil {
return nil, true, nil
}
return v.GetValue(), true, nil
case *wrapperspb.BytesValue:
if v == nil {
return nil, true, nil
}
return v.GetValue(), true, nil
case *wrapperspb.DoubleValue:
if v == nil {
return nil, true, nil
}
return v.GetValue(), true, nil
case *wrapperspb.FloatValue:
if v == nil {
return nil, true, nil
}
return float64(v.GetValue()), true, nil
case *wrapperspb.Int32Value:
if v == nil {
return nil, true, nil
}
return int64(v.GetValue()), true, nil
case *wrapperspb.Int64Value:
if v == nil {
return nil, true, nil
}
return v.GetValue(), true, nil
case *wrapperspb.StringValue:
if v == nil {
return nil, true, nil
}
return v.GetValue(), true, nil
case *wrapperspb.UInt32Value:
if v == nil {
return nil, true, nil
}
return uint64(v.GetValue()), true, nil
case *wrapperspb.UInt64Value:
if v == nil {
return nil, true, nil
}
return v.GetValue(), true, nil
}
return msg, false, nil
Expand Down
36 changes: 36 additions & 0 deletions common/types/pb/type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,38 +354,74 @@ func TestTypeDescriptionMaybeUnwrap(t *testing.T) {
in: wrapperspb.Bool(true),
out: true,
},
{
in: (*wrapperspb.BoolValue)(nil),
out: nil,
},
{
in: wrapperspb.Bytes([]byte("hello")),
out: []byte("hello"),
},
{
in: (*wrapperspb.BytesValue)(nil),
out: nil,
},
{
in: wrapperspb.Double(-4.2),
out: -4.2,
},
{
in: (*wrapperspb.DoubleValue)(nil),
out: nil,
},
{
in: wrapperspb.Float(4.5),
out: 4.5,
},
{
in: (*wrapperspb.FloatValue)(nil),
out: nil,
},
{
in: wrapperspb.Int32(123),
out: int64(123),
},
{
in: (*wrapperspb.Int32Value)(nil),
out: nil,
},
{
in: wrapperspb.Int64(456),
out: int64(456),
},
{
in: (*wrapperspb.Int64Value)(nil),
out: nil,
},
{
in: wrapperspb.String("goodbye"),
out: "goodbye",
},
{
in: (*wrapperspb.StringValue)(nil),
out: nil,
},
{
in: wrapperspb.UInt32(1234),
out: uint64(1234),
},
{
in: (*wrapperspb.UInt32Value)(nil),
out: nil,
},
{
in: wrapperspb.UInt64(5678),
out: uint64(5678),
},
{
in: (*wrapperspb.UInt64Value)(nil),
out: nil,
},
{
in: tpb.New(time.Unix(12345, 0).UTC()),
out: time.Unix(12345, 0).UTC(),
Expand Down
9 changes: 9 additions & 0 deletions common/types/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -695,22 +695,31 @@ func TestNativeToValue_Wrappers(t *testing.T) {
// Wrapper conversion test.
expectNativeToValue(t, wrapperspb.Bool(true), True)
expectNativeToValue(t, &wrapperspb.BoolValue{}, False)
expectNativeToValue(t, (*wrapperspb.BoolValue)(nil), NullValue)
expectNativeToValue(t, &wrapperspb.BytesValue{}, Bytes{})
expectNativeToValue(t, wrapperspb.Bytes([]byte("hi")), Bytes("hi"))
expectNativeToValue(t, (*wrapperspb.BytesValue)(nil), NullValue)
expectNativeToValue(t, &wrapperspb.DoubleValue{}, Double(0.0))
expectNativeToValue(t, wrapperspb.Double(6.4), Double(6.4))
expectNativeToValue(t, (*wrapperspb.DoubleValue)(nil), NullValue)
expectNativeToValue(t, &wrapperspb.FloatValue{}, Double(0.0))
expectNativeToValue(t, wrapperspb.Float(3.0), Double(3.0))
expectNativeToValue(t, (*wrapperspb.FloatValue)(nil), NullValue)
expectNativeToValue(t, &wrapperspb.Int32Value{}, IntZero)
expectNativeToValue(t, wrapperspb.Int32(-32), Int(-32))
expectNativeToValue(t, (*wrapperspb.Int32Value)(nil), NullValue)
expectNativeToValue(t, &wrapperspb.Int64Value{}, IntZero)
expectNativeToValue(t, wrapperspb.Int64(-64), Int(-64))
expectNativeToValue(t, (*wrapperspb.Int64Value)(nil), NullValue)
expectNativeToValue(t, &wrapperspb.StringValue{}, String(""))
expectNativeToValue(t, wrapperspb.String("hello"), String("hello"))
expectNativeToValue(t, (*wrapperspb.StringValue)(nil), NullValue)
expectNativeToValue(t, &wrapperspb.UInt32Value{}, Uint(0))
expectNativeToValue(t, wrapperspb.UInt32(32), Uint(32))
expectNativeToValue(t, (*wrapperspb.UInt32Value)(nil), NullValue)
expectNativeToValue(t, &wrapperspb.UInt64Value{}, Uint(0))
expectNativeToValue(t, wrapperspb.UInt64(64), Uint(64))
expectNativeToValue(t, (*wrapperspb.UInt64Value)(nil), NullValue)
}

func TestNativeToValue_Primitive(t *testing.T) {
Expand Down

0 comments on commit 25457de

Please sign in to comment.