diff --git a/internal/catalog/internal/types/service_endpoints.go b/internal/catalog/internal/types/service_endpoints.go index 8c3d1cf10956..be3475103172 100644 --- a/internal/catalog/internal/types/service_endpoints.go +++ b/internal/catalog/internal/types/service_endpoints.go @@ -44,6 +44,16 @@ func MutateServiceEndpoints(res *pbresource.Resource) error { } } + return nil +} + +func ValidateServiceEndpoints(res *pbresource.Resource) error { + var svcEndpoints pbcatalog.ServiceEndpoints + + if err := res.Data.UnmarshalTo(&svcEndpoints); err != nil { + return resource.NewErrDataParse(&svcEndpoints, err) + } + var err error if !resource.EqualType(res.Owner.Type, ServiceV1Alpha1Type) { err = multierror.Append(err, resource.ErrOwnerTypeInvalid{ @@ -54,6 +64,7 @@ func MutateServiceEndpoints(res *pbresource.Resource) error { if !resource.EqualTenancy(res.Owner.Tenancy, res.Id.Tenancy) { err = multierror.Append(err, resource.ErrOwnerTenantInvalid{ + ResourceType: ServiceEndpointsV1Alpha1Type, ResourceTenancy: res.Id.Tenancy, OwnerTenancy: res.Owner.Tenancy, }) @@ -69,17 +80,6 @@ func MutateServiceEndpoints(res *pbresource.Resource) error { }) } - return err -} - -func ValidateServiceEndpoints(res *pbresource.Resource) error { - var svcEndpoints pbcatalog.ServiceEndpoints - - if err := res.Data.UnmarshalTo(&svcEndpoints); err != nil { - return resource.NewErrDataParse(&svcEndpoints, err) - } - - var err error for idx, endpoint := range svcEndpoints.Endpoints { if endpointErr := validateEndpoint(endpoint, res); endpointErr != nil { err = multierror.Append(err, resource.ErrInvalidListElement{ diff --git a/internal/catalog/internal/types/service_endpoints_test.go b/internal/catalog/internal/types/service_endpoints_test.go index bd902d624683..25492577d390 100644 --- a/internal/catalog/internal/types/service_endpoints_test.go +++ b/internal/catalog/internal/types/service_endpoints_test.go @@ -7,11 +7,10 @@ import ( "testing" "github.com/hashicorp/consul/internal/resource" + rtest "github.com/hashicorp/consul/internal/resource/resourcetest" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v1alpha1" "github.com/hashicorp/consul/proto-public/pbresource" "github.com/stretchr/testify/require" - "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/known/anypb" ) var ( @@ -20,22 +19,13 @@ var ( Namespace: "default", PeerName: "local", } -) -func createServiceEndpointsResource(t *testing.T, data protoreflect.ProtoMessage) *pbresource.Resource { - res := &pbresource.Resource{ - Id: &pbresource.ID{ - Type: ServiceEndpointsType, - Tenancy: defaultEndpointTenancy, - Name: "test-service", - }, + badEndpointTenancy = &pbresource.Tenancy{ + Partition: "default", + Namespace: "bad", + PeerName: "local", } - - var err error - res.Data, err = anypb.New(data) - require.NoError(t, err) - return res -} +) func TestValidateServiceEndpoints_Ok(t *testing.T) { data := &pbcatalog.ServiceEndpoints{ @@ -62,8 +52,14 @@ func TestValidateServiceEndpoints_Ok(t *testing.T) { }, } - res := createServiceEndpointsResource(t, data) + res := rtest.Resource(ServiceEndpointsType, "test-service"). + WithData(t, data). + Build() + + // fill in owner automatically + require.NoError(t, MutateServiceEndpoints(res)) + // Now validate that everything is good. err := ValidateServiceEndpoints(res) require.NoError(t, err) } @@ -73,7 +69,7 @@ func TestValidateServiceEndpoints_ParseError(t *testing.T) { // to cause the error we are expecting data := &pbcatalog.IP{Address: "198.18.0.1"} - res := createServiceEndpointsResource(t, data) + res := rtest.Resource(ServiceEndpointsType, "test-service").WithData(t, data).Build() err := ValidateServiceEndpoints(res) require.Error(t, err) @@ -104,6 +100,7 @@ func TestValidateServiceEndpoints_EndpointInvalid(t *testing.T) { } type testCase struct { + owner *pbresource.ID modify func(*pbcatalog.Endpoint) validateErr func(t *testing.T, err error) } @@ -140,11 +137,11 @@ func TestValidateServiceEndpoints_EndpointInvalid(t *testing.T) { } }, validateErr: func(t *testing.T, err error) { - var mapErr resource.ErrInvalidMapKey - require.ErrorAs(t, err, &mapErr) - require.Equal(t, "ports", mapErr.Map) - require.Equal(t, "", mapErr.Key) - require.Equal(t, resource.ErrEmpty, mapErr.Wrapped) + rtest.RequireError(t, err, resource.ErrInvalidMapKey{ + Map: "ports", + Key: "", + Wrapped: resource.ErrEmpty, + }) }, }, "port-0": { @@ -163,18 +160,50 @@ func TestValidateServiceEndpoints_EndpointInvalid(t *testing.T) { require.ErrorIs(t, err, errInvalidPhysicalPort) }, }, + "invalid-owner": { + owner: &pbresource.ID{ + Type: DNSPolicyType, + Tenancy: badEndpointTenancy, + Name: "wrong", + }, + validateErr: func(t *testing.T, err error) { + rtest.RequireError(t, err, resource.ErrOwnerTypeInvalid{ + ResourceType: ServiceEndpointsType, + OwnerType: DNSPolicyType}) + rtest.RequireError(t, err, resource.ErrOwnerTenantInvalid{ + ResourceType: ServiceEndpointsType, + ResourceTenancy: defaultEndpointTenancy, + OwnerTenancy: badEndpointTenancy, + }) + rtest.RequireError(t, err, resource.ErrInvalidField{ + Name: "name", + Wrapped: errInvalidEndpointsOwnerName{ + Name: "test-service", + OwnerName: "wrong"}, + }) + }, + }, } for name, tcase := range cases { t.Run(name, func(t *testing.T) { - data := genData() - tcase.modify(data) + endpoint := genData() + if tcase.modify != nil { + tcase.modify(endpoint) + } - res := createServiceEndpointsResource(t, &pbcatalog.ServiceEndpoints{ + data := &pbcatalog.ServiceEndpoints{ Endpoints: []*pbcatalog.Endpoint{ - data, + endpoint, }, - }) + } + res := rtest.Resource(ServiceEndpointsType, "test-service"). + WithOwner(tcase.owner). + WithData(t, data). + Build() + + // Run the mututation to setup defaults + require.NoError(t, MutateServiceEndpoints(res)) err := ValidateServiceEndpoints(res) require.Error(t, err) @@ -182,3 +211,13 @@ func TestValidateServiceEndpoints_EndpointInvalid(t *testing.T) { }) } } + +func TestMutateServiceEndpoints_PopulateOwner(t *testing.T) { + res := rtest.Resource(ServiceEndpointsType, "test-service").Build() + + require.NoError(t, MutateServiceEndpoints(res)) + require.NotNil(t, res.Owner) + require.True(t, resource.EqualType(res.Owner.Type, ServiceType)) + require.True(t, resource.EqualTenancy(res.Owner.Tenancy, defaultEndpointTenancy)) + require.Equal(t, res.Owner.Name, res.Id.Name) +} diff --git a/internal/resource/errors.go b/internal/resource/errors.go index c258f9ad35b1..8eaf9e2259c1 100644 --- a/internal/resource/errors.go +++ b/internal/resource/errors.go @@ -4,7 +4,6 @@ package resource import ( - "errors" "fmt" "github.com/hashicorp/consul/proto-public/pbresource" @@ -12,11 +11,33 @@ import ( ) var ( - ErrMissing = errors.New("missing required field") - ErrEmpty = errors.New("cannot be empty") - ErrReferenceTenancyNotEqual = errors.New("resource tenancy and reference tenancy differ") + ErrMissing = NewConstError("missing required field") + ErrEmpty = NewConstError("cannot be empty") + ErrReferenceTenancyNotEqual = NewConstError("resource tenancy and reference tenancy differ") ) +// ConstError is more or less equivalent to the stdlib errors.errorstring. However, having +// our own exported type allows us to more accurately compare error values in tests. +// +// - go-cmp will not compared unexported fields by default. +// - cmp.AllowUnexported() requires a concrete struct type and due to the stdlib not +// exporting the errorstring type there doesn't seem to be a way to get at the type. +// - cmpopts.EquateErrors has issues with protobuf types within other error structs. +// +// Due to these factors the easiest thing to do is to create a custom comparer for +// the ConstError type and use it where necessary. +type ConstError struct { + message string +} + +func NewConstError(msg string) ConstError { + return ConstError{message: msg} +} + +func (e ConstError) Error() string { + return e.message +} + type ErrDataParse struct { TypeName string Wrapped error diff --git a/internal/resource/resourcetest/builder.go b/internal/resource/resourcetest/builder.go index 38f1a6e3ec4a..11e82c07839c 100644 --- a/internal/resource/resourcetest/builder.go +++ b/internal/resource/resourcetest/builder.go @@ -55,6 +55,11 @@ func ResourceID(id *pbresource.ID) *resourceBuilder { } } +func (b *resourceBuilder) WithTenancy(tenant *pbresource.Tenancy) *resourceBuilder { + b.resource.Id.Tenancy = tenant + return b +} + func (b *resourceBuilder) WithData(t T, data protoreflect.ProtoMessage) *resourceBuilder { t.Helper() diff --git a/internal/resource/resourcetest/require.go b/internal/resource/resourcetest/require.go index fff8cb2aebf2..8e102398fa87 100644 --- a/internal/resource/resourcetest/require.go +++ b/internal/resource/resourcetest/require.go @@ -2,12 +2,38 @@ package resourcetest import ( "github.com/google/go-cmp/cmp" + "github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto/private/prototest" "github.com/stretchr/testify/require" "google.golang.org/protobuf/testing/protocmp" ) +// CompareErrorString is a helper to generate a custom go-cmp comparer method +// that will perform an equality check on the error message. This is mainly +// useful to get around not being able to see unexported data within errors. +func CompareErrorString[T error]() cmp.Option { + return cmp.Comparer(func(e1, e2 T) bool { + return e1.Error() == e2.Error() + }) +} + +// default comparers for known types that don't play well with go-cmp +var comparers = []cmp.Option{ + CompareErrorString[resource.ConstError](), +} + +// RequireError is useful for asserting that some chained multierror contains a specific error. +func RequireError[E error](t T, err error, expected E, opts ...cmp.Option) { + t.Helper() + + var actual E + require.ErrorAs(t, err, &actual) + + opts = append(opts, comparers...) + prototest.AssertDeepEqual(t, expected, actual, opts...) +} + func RequireVersionUnchanged(t T, res *pbresource.Resource, version string) { t.Helper() require.Equal(t, version, res.Version)