diff --git a/manager/controlapi/service.go b/manager/controlapi/service.go index 656533d970..3eb2788ed4 100644 --- a/manager/controlapi/service.go +++ b/manager/controlapi/service.go @@ -149,13 +149,13 @@ func validateEndpointSpec(epSpec *api.EndpointSpec) error { return grpc.Errorf(codes.InvalidArgument, "EndpointSpec: ports can't be used with dnsrr mode") } - portSet := make(map[api.PortConfig]struct{}) + portSet := make(map[uint32]struct{}) for _, port := range epSpec.Ports { - if _, ok := portSet[*port]; ok { - return grpc.Errorf(codes.InvalidArgument, "EndpointSpec: duplicate ports provided") + if _, ok := portSet[port.PublishedPort]; ok { + return grpc.Errorf(codes.InvalidArgument, "EndpointSpec: duplicate published ports provided") } - portSet[*port] = struct{}{} + portSet[port.PublishedPort] = struct{}{} } return nil diff --git a/manager/controlapi/service_test.go b/manager/controlapi/service_test.go index 8833211110..92c912c631 100644 --- a/manager/controlapi/service_test.go +++ b/manager/controlapi/service_test.go @@ -490,14 +490,57 @@ func TestRemoveService(t *testing.T) { } func TestValidateEndpointSpec(t *testing.T) { - err := validateEndpointSpec(&api.EndpointSpec{ + endPointSpec1 := &api.EndpointSpec{ Mode: api.ResolutionModeDNSRoundRobin, Ports: []*api.PortConfig{ { - Name: "http", TargetPort: 80, + Name: "http", + TargetPort: 80, }, }, - }) + } + + endPointSpec2 := &api.EndpointSpec{ + Mode: api.ResolutionModeVirtualIP, + Ports: []*api.PortConfig{ + { + Name: "http", + TargetPort: 81, + PublishedPort: 8001, + }, + { + Name: "http", + TargetPort: 80, + PublishedPort: 8000, + }, + }, + } + + // has duplicated published port, invalid + endPointSpec3 := &api.EndpointSpec{ + Mode: api.ResolutionModeVirtualIP, + Ports: []*api.PortConfig{ + { + Name: "http", + TargetPort: 81, + PublishedPort: 8001, + }, + { + Name: "http", + TargetPort: 80, + PublishedPort: 8001, + }, + }, + } + + err := validateEndpointSpec(endPointSpec1) + assert.Error(t, err) + assert.Equal(t, codes.InvalidArgument, grpc.Code(err)) + + err = validateEndpointSpec(endPointSpec2) + assert.NoError(t, err) + + err = validateEndpointSpec(endPointSpec3) assert.Error(t, err) assert.Equal(t, codes.InvalidArgument, grpc.Code(err)) }