Skip to content

Commit

Permalink
backport of commit c06d654
Browse files Browse the repository at this point in the history
  • Loading branch information
nathancoleman committed Jul 3, 2024
1 parent c86ebc5 commit 50c904d
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 34 deletions.
4 changes: 2 additions & 2 deletions control-plane/api-gateway/gatekeeper/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,12 @@ func (g Gatekeeper) initContainer(config common.HelmConfig, name, namespace stri
return corev1.Container{}, fmt.Errorf("error getting namespace metadata for deployment: %s", err)
}

uid, err = ctrlCommon.GetOpenShiftUID(ns)
uid, err = ctrlCommon.GetOpenShiftUID(ns, ctrlCommon.SelectFirstInRange)

if err != nil {
return corev1.Container{}, err
}
group, err = ctrlCommon.GetOpenShiftGroup(ns)
group, err = ctrlCommon.GetOpenShiftGroup(ns, ctrlCommon.SelectFirstInRange)
if err != nil {
return corev1.Container{}, err
}
Expand Down
78 changes: 60 additions & 18 deletions control-plane/connect-inject/common/openshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ import (
"strconv"
"strings"

"github.com/hashicorp/consul-k8s/control-plane/connect-inject/constants"
corev1 "k8s.io/api/core/v1"

"github.com/hashicorp/consul-k8s/control-plane/connect-inject/constants"
)

// GetOpenShiftUID gets the user id from the OpenShift annotation 'openshift.io/sa.scc.uid-range'.
func GetOpenShiftUID(ns *corev1.Namespace) (int64, error) {
// Select the last in the range so we don't conflict with any ID assigned to application containers.
func GetOpenShiftUID(ns *corev1.Namespace, selector idSelector) (int64, error) {
annotation, ok := ns.Annotations[constants.AnnotationOpenShiftUIDRange]
if !ok {
return 0, fmt.Errorf("unable to find annotation %s", constants.AnnotationOpenShiftUIDRange)
Expand All @@ -34,7 +36,7 @@ func GetOpenShiftUID(ns *corev1.Namespace) (int64, error) {
return 0, fmt.Errorf("found annotation %s but it was empty", constants.AnnotationOpenShiftUIDRange)
}

uid, err := parseOpenShiftUID(annotation)
uid, err := parseOpenShiftUID(annotation, selector)
if err != nil {
return 0, err
}
Expand All @@ -45,15 +47,11 @@ func GetOpenShiftUID(ns *corev1.Namespace) (int64, error) {
// parseOpenShiftUID parses the UID "range" from the annotation string. The annotation can either have a '/' or '-'
// as a separator. '-' is the old style of UID from when it used to be an actual range.
// Example annotation value: "1000700000/100000".
func parseOpenShiftUID(val string) (int64, error) {
func parseOpenShiftUID(val string, selector idSelector) (int64, error) {
var uid int64
var err error
if strings.Contains(val, "/") {
str := strings.Split(val, "/")
uid, err = strconv.ParseInt(str[0], 10, 64)
if err != nil {
return 0, err
}
return selectIDInRange(val, selector)
}
if strings.Contains(val, "-") {
str := strings.Split(val, "-")
Expand All @@ -77,7 +75,8 @@ func parseOpenShiftUID(val string) (int64, error) {
// GetOpenShiftGroup gets the group from OpenShift annotation 'openshift.io/sa.scc.supplemental-groups'
// Fall back to the UID annotation if the group annotation does not exist. The values should
// be the same.
func GetOpenShiftGroup(ns *corev1.Namespace) (int64, error) {
// Select the last in the range so we don't conflict with any ID assigned randomly to application containers.
func GetOpenShiftGroup(ns *corev1.Namespace, selector idSelector) (int64, error) {
annotation, ok := ns.Annotations[constants.AnnotationOpenShiftGroups]
if !ok {
// fall back to UID annotation
Expand All @@ -94,25 +93,21 @@ func GetOpenShiftGroup(ns *corev1.Namespace) (int64, error) {
return 0, fmt.Errorf("found annotation %s but it was empty", constants.AnnotationOpenShiftGroups)
}

uid, err := parseOpenShiftGroup(annotation)
gid, err := parseOpenShiftGroup(annotation, selector)
if err != nil {
return 0, err
}

return uid, nil
return gid, nil
}

// parseOpenShiftGroup parses the group from the annotation string. The annotation can either have a '/' or ','
// as a separator. ',' is the old style of UID from when it used to be an actual range.
func parseOpenShiftGroup(val string) (int64, error) {
func parseOpenShiftGroup(val string, selector idSelector) (int64, error) {
var group int64
var err error
if strings.Contains(val, "/") {
str := strings.Split(val, "/")
group, err = strconv.ParseInt(str[0], 10, 64)
if err != nil {
return 0, err
}
return selectIDInRange(val, selector)
}
if strings.Contains(val, ",") {
str := strings.Split(val, ",")
Expand All @@ -128,3 +123,50 @@ func parseOpenShiftGroup(val string) (int64, error) {

return group, nil
}

type idSelector func(values []int64) (int64, error)

var SelectFirstInRange idSelector = func(values []int64) (int64, error) {
if len(values) < 1 {
return 0, fmt.Errorf("range must have at least 1 value")
}
return values[0], nil
}

var SelectSidecarID idSelector = func(values []int64) (int64, error) {
if len(values) < 2 {
return 0, fmt.Errorf("range must have at least 2 values")
}
return values[len(values)-2], nil
}

var SelectInitContainerID idSelector = func(values []int64) (int64, error) {
if len(values) < 1 {
return 0, fmt.Errorf("range must have at least 1 value")
}
return values[len(values)-1], nil
}

func selectIDInRange(value string, selector idSelector) (int64, error) {
parts := strings.Split(value, "/")
if len(parts) != 2 {
return 0, fmt.Errorf("invalid range format: %s", value)
}

start, err := strconv.Atoi(parts[0])
if err != nil {
return 0, fmt.Errorf("invalid range format: %s", parts[0])
}

length, err := strconv.Atoi(parts[1])
if err != nil {
return 0, fmt.Errorf("invalid range format: %s", parts[1])
}

values := make([]int64, length)
for i := 0; i < length; i++ {
values[i] = int64(start + i)
}

return selector(values)
}
7 changes: 4 additions & 3 deletions control-plane/connect-inject/common/openshift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import (
"fmt"
"testing"

"github.com/hashicorp/consul-k8s/control-plane/connect-inject/constants"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/hashicorp/consul-k8s/control-plane/connect-inject/constants"
)

func TestOpenShiftUID(t *testing.T) {
Expand Down Expand Up @@ -110,7 +111,7 @@ func TestOpenShiftUID(t *testing.T) {
for _, tt := range cases {
t.Run(tt.Name, func(t *testing.T) {
require := require.New(t)
actual, err := GetOpenShiftUID(tt.Namespace())
actual, err := GetOpenShiftUID(tt.Namespace(), SelectFirstInRange)
if tt.Err == "" {
require.NoError(err)
require.Equal(tt.Expected, actual)
Expand Down Expand Up @@ -224,7 +225,7 @@ func TestOpenShiftGroup(t *testing.T) {
for _, tt := range cases {
t.Run(tt.Name, func(t *testing.T) {
require := require.New(t)
actual, err := GetOpenShiftGroup(tt.Namespace())
actual, err := GetOpenShiftGroup(tt.Namespace(), SelectFirstInRange)
if tt.Err == "" {
require.NoError(err)
require.Equal(tt.Expected, actual)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,11 @@ func (w *MeshWebhook) consulDataplaneSidecar(namespace corev1.Namespace, pod cor
// Transparent proxy is set in OpenShift. There is an annotation on the namespace that tells us what
// the user and group ids should be for the sidecar.
var err error
uid, err = common.GetOpenShiftUID(&namespace)
uid, err = common.GetOpenShiftUID(&namespace, common.SelectSidecarID)
if err != nil {
return corev1.Container{}, err
}
group, err = common.GetOpenShiftGroup(&namespace)
group, err = common.GetOpenShiftGroup(&namespace, common.SelectSidecarID)
if err != nil {
return corev1.Container{}, err
}
Expand Down
11 changes: 6 additions & 5 deletions control-plane/connect-inject/webhook/container_init.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import (
"strings"
"text/template"

"github.com/hashicorp/consul-k8s/control-plane/connect-inject/common"
"github.com/hashicorp/consul-k8s/control-plane/connect-inject/constants"
corev1 "k8s.io/api/core/v1"
"k8s.io/utils/pointer"

"github.com/hashicorp/consul-k8s/control-plane/connect-inject/common"
"github.com/hashicorp/consul-k8s/control-plane/connect-inject/constants"
)

const (
Expand Down Expand Up @@ -240,12 +241,12 @@ func (w *MeshWebhook) containerInit(namespace corev1.Namespace, pod corev1.Pod,
if w.EnableOpenShift {
var err error

uid, err = common.GetOpenShiftUID(&namespace)

uid, err = common.GetOpenShiftUID(&namespace, common.SelectInitContainerID)
if err != nil {
return corev1.Container{}, err
}
group, err = common.GetOpenShiftGroup(&namespace)

group, err = common.GetOpenShiftGroup(&namespace, common.SelectInitContainerID)
if err != nil {
return corev1.Container{}, err
}
Expand Down
14 changes: 10 additions & 4 deletions control-plane/connect-inject/webhook/redirect_traffic.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,23 @@ func (w *MeshWebhook) iptablesConfigJSON(pod corev1.Pod, ns corev1.Namespace) (s

if !w.EnableOpenShift {
cfg.ProxyUserID = strconv.Itoa(sidecarUserAndGroupID)

// Add init container user ID to exclude from traffic redirection.
cfg.ExcludeUIDs = append(cfg.ExcludeUIDs, strconv.Itoa(initContainersUserAndGroupID))
} else {
// When using OpenShift, the uid and group are saved as an annotation on the namespace
uid, err := common.GetOpenShiftUID(&ns)
uid, err := common.GetOpenShiftUID(&ns, common.SelectSidecarID)
if err != nil {
return "", err
}
cfg.ProxyUserID = strconv.FormatInt(uid, 10)

// Exclude the user ID for the init container from traffic redirection.
uid, err = common.GetOpenShiftGroup(&ns, common.SelectInitContainerID)
if err != nil {
return "", err
}
cfg.ExcludeUIDs = append(cfg.ExcludeUIDs, strconv.FormatInt(uid, 10))
}

// Set the proxy's inbound port.
Expand Down Expand Up @@ -110,9 +119,6 @@ func (w *MeshWebhook) iptablesConfigJSON(pod corev1.Pod, ns corev1.Namespace) (s
excludeUIDs := splitCommaSeparatedItemsFromAnnotation(constants.AnnotationTProxyExcludeUIDs, pod)
cfg.ExcludeUIDs = append(cfg.ExcludeUIDs, excludeUIDs...)

// Add init container user ID to exclude from traffic redirection.
cfg.ExcludeUIDs = append(cfg.ExcludeUIDs, strconv.Itoa(initContainersUserAndGroupID))

dnsEnabled, err := consulDNSEnabled(ns, pod, w.EnableConsulDNS, w.EnableTransparentProxy)
if err != nil {
return "", err
Expand Down

0 comments on commit 50c904d

Please sign in to comment.