Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a helper for decoding protobuf enums #47230

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 11 additions & 55 deletions api/types/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -1118,59 +1118,15 @@ func (r *RequireMFAType) encode() (interface{}, error) {
// decode RequireMFAType from a string or boolean. This is necessary for
// backwards compatibility with the json/yaml tag "require_session_mfa",
// which used to be a boolean.
func (r *RequireMFAType) decode(val interface{}) error {
switch v := val.(type) {
case string:
switch v {
case RequireMFATypeHardwareKeyString:
*r = RequireMFAType_SESSION_AND_HARDWARE_KEY
case RequireMFATypeHardwareKeyTouchString:
*r = RequireMFAType_HARDWARE_KEY_TOUCH
case RequireMFATypeHardwareKeyPINString:
*r = RequireMFAType_HARDWARE_KEY_PIN
case RequireMFATypeHardwareKeyTouchAndPINString:
*r = RequireMFAType_HARDWARE_KEY_TOUCH_AND_PIN
case "":
// default to off
*r = RequireMFAType_OFF
default:
// try parsing as a boolean
switch strings.ToLower(v) {
case "yes", "yeah", "y", "true", "1", "on":
*r = RequireMFAType_SESSION
case "no", "nope", "n", "false", "0", "off":
*r = RequireMFAType_OFF
default:
return trace.BadParameter("RequireMFAType invalid value %v", val)
}
}
case bool:
if v {
*r = RequireMFAType_SESSION
} else {
*r = RequireMFAType_OFF
}
case int32:
return trace.Wrap(r.setFromEnum(v))
case int64:
return trace.Wrap(r.setFromEnum(int32(v)))
case int:
return trace.Wrap(r.setFromEnum(int32(v)))
case float64:
return trace.Wrap(r.setFromEnum(int32(v)))
case float32:
return trace.Wrap(r.setFromEnum(int32(v)))
default:
return trace.BadParameter("RequireMFAType invalid type %T", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (r *RequireMFAType) setFromEnum(val int32) error {
if _, ok := RequireMFAType_name[val]; !ok {
return trace.BadParameter("invalid required mfa mode %v", val)
}
*r = RequireMFAType(val)
return nil
func (r *RequireMFAType) decode(val any) error {
err := decodeEnum(r, val, map[any]RequireMFAType{
"": RequireMFAType_OFF, // default to off
false: RequireMFAType_OFF,
true: RequireMFAType_SESSION,
RequireMFATypeHardwareKeyString: RequireMFAType_SESSION_AND_HARDWARE_KEY,
RequireMFATypeHardwareKeyTouchString: RequireMFAType_HARDWARE_KEY_TOUCH,
RequireMFATypeHardwareKeyPINString: RequireMFAType_HARDWARE_KEY_PIN,
RequireMFATypeHardwareKeyTouchAndPINString: RequireMFAType_HARDWARE_KEY_TOUCH_AND_PIN,
}, RequireMFAType_name)
return trace.Wrap(err, "failed to decode require mfa type")
}
82 changes: 82 additions & 0 deletions api/types/enum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
Copyright 2024 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package types

import (
"strings"

"github.com/gravitational/trace"
)

// decodeEnum decodes a protobuf enum from a representational value, usually a bool,
// string, or from the actual enum (int32) value. If the value is valid, it is saved
// in the given enum pointer.
func decodeEnum[T ~int32](p *T, val any, representationMap map[any]T, enumMap map[int32]string) error {
if v, ok := representationMap[val]; ok {
*p = v
return nil
}

// try parsing as a bool value
if v, ok := val.(string); ok {
switch strings.ToLower(v) {
case "yes", "yeah", "y", "true", "1", "on":
if v, ok := representationMap[true]; ok {
*p = v
return nil
}
case "no", "nope", "n", "false", "0", "off":
if v, ok := representationMap[false]; ok {
*p = v
return nil
}
}
return trace.BadParameter("unknown enum value %v", val)
}

// parse as enum
var enumVal T
switch v := val.(type) {
case int:
enumVal = T(v)
case int32:
enumVal = T(v)
case int64:
enumVal = T(v)
case float64:
enumVal = T(v)
case float32:
enumVal = T(v)
default:
return trace.BadParameter("unknown enum value %v", val)
}

if err := checkEnum(enumMap, int32(enumVal)); err != nil {
return trace.BadParameter("unknown enum value %v", val)
}

*p = enumVal
return nil
}

// checkEnum checks if the given enum is valid.
func checkEnum(enumMap map[int32]string, val int32) error {
if _, ok := enumMap[val]; ok {
return nil
}
return trace.NotFound("enum %v not found in enum map", val)
}
87 changes: 16 additions & 71 deletions api/types/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -1998,55 +1998,15 @@ func (h CreateHostUserMode) encode() (string, error) {
}

func (h *CreateHostUserMode) decode(val any) error {
var valS string
switch val := val.(type) {
case int32:
return trace.Wrap(h.setFromEnum(val))
case int64:
return trace.Wrap(h.setFromEnum(int32(val)))
case int:
return trace.Wrap(h.setFromEnum(int32(val)))
case float64:
return trace.Wrap(h.setFromEnum(int32(val)))
case float32:
return trace.Wrap(h.setFromEnum(int32(val)))
case string:
valS = val
case bool:
if val {
return trace.BadParameter("create_host_user_mode cannot be true, got %v", val)
}
valS = createHostUserModeOffString
default:
return trace.BadParameter("bad value type %T, expected string or int", val)
}

switch valS {
case "":
*h = CreateHostUserMode_HOST_USER_MODE_UNSPECIFIED
case createHostUserModeOffString:
*h = CreateHostUserMode_HOST_USER_MODE_OFF
case createHostUserModeKeepString:
*h = CreateHostUserMode_HOST_USER_MODE_KEEP
case createHostUserModeInsecureDropString, createHostUserModeDropString:
*h = CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP
default:
return trace.BadParameter("invalid host user mode %v", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (h *CreateHostUserMode) setFromEnum(val int32) error {
// Map drop to insecure-drop
if val == int32(CreateHostUserMode_HOST_USER_MODE_DROP) {
val = int32(CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP)
}
if _, ok := CreateHostUserMode_name[val]; !ok {
return trace.BadParameter("invalid host user mode %v", val)
}
*h = CreateHostUserMode(val)
return nil
err := decodeEnum(h, val, map[interface{}]CreateHostUserMode{
"": CreateHostUserMode_HOST_USER_MODE_UNSPECIFIED,
false: CreateHostUserMode_HOST_USER_MODE_OFF,
createHostUserModeOffString: CreateHostUserMode_HOST_USER_MODE_OFF,
createHostUserModeKeepString: CreateHostUserMode_HOST_USER_MODE_KEEP,
createHostUserModeInsecureDropString: CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP,
createHostUserModeDropString: CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP,
}, CreateHostUserMode_name)
return trace.Wrap(err, "failed to decode host user mode")
}

// UnmarshalYAML supports parsing CreateHostUserMode from string.
Expand Down Expand Up @@ -2114,28 +2074,13 @@ func (h CreateDatabaseUserMode) encode() (string, error) {
}

func (h *CreateDatabaseUserMode) decode(val any) error {
var str string
switch val := val.(type) {
case string:
str = val
default:
return trace.BadParameter("bad value type %T, expected string", val)
}

switch str {
case "":
*h = CreateDatabaseUserMode_DB_USER_MODE_UNSPECIFIED
case createDatabaseUserModeOffString:
*h = CreateDatabaseUserMode_DB_USER_MODE_OFF
case createDatabaseUserModeKeepString:
*h = CreateDatabaseUserMode_DB_USER_MODE_KEEP
case createDatabaseUserModeBestEffortDropString:
*h = CreateDatabaseUserMode_DB_USER_MODE_BEST_EFFORT_DROP
default:
return trace.BadParameter("invalid database user mode %v", val)
}

return nil
err := decodeEnum(h, val, map[interface{}]CreateDatabaseUserMode{
"": CreateDatabaseUserMode_DB_USER_MODE_UNSPECIFIED,
createDatabaseUserModeOffString: CreateDatabaseUserMode_DB_USER_MODE_OFF,
createDatabaseUserModeKeepString: CreateDatabaseUserMode_DB_USER_MODE_KEEP,
createDatabaseUserModeBestEffortDropString: CreateDatabaseUserMode_DB_USER_MODE_BEST_EFFORT_DROP,
}, CreateDatabaseUserMode_name)
return trace.Wrap(err, "failed to decode require mfa type")
}

// UnmarshalYAML supports parsing CreateDatabaseUserMode from string.
Expand Down
Loading