diff --git a/api/types/authentication.go b/api/types/authentication.go index 6a0d7f40a453..a224a0c104bd 100644 --- a/api/types/authentication.go +++ b/api/types/authentication.go @@ -1118,59 +1118,15 @@ func (r *RequireMFAType) encode() (interface{}, error) { // decode RequireMFAType from a string or boolean. This is necessary for // backwards compatibility with the json/yaml tag "require_session_mfa", // which used to be a boolean. -func (r *RequireMFAType) decode(val interface{}) error { - switch v := val.(type) { - case string: - switch v { - case RequireMFATypeHardwareKeyString: - *r = RequireMFAType_SESSION_AND_HARDWARE_KEY - case RequireMFATypeHardwareKeyTouchString: - *r = RequireMFAType_HARDWARE_KEY_TOUCH - case RequireMFATypeHardwareKeyPINString: - *r = RequireMFAType_HARDWARE_KEY_PIN - case RequireMFATypeHardwareKeyTouchAndPINString: - *r = RequireMFAType_HARDWARE_KEY_TOUCH_AND_PIN - case "": - // default to off - *r = RequireMFAType_OFF - default: - // try parsing as a boolean - switch strings.ToLower(v) { - case "yes", "yeah", "y", "true", "1", "on": - *r = RequireMFAType_SESSION - case "no", "nope", "n", "false", "0", "off": - *r = RequireMFAType_OFF - default: - return trace.BadParameter("RequireMFAType invalid value %v", val) - } - } - case bool: - if v { - *r = RequireMFAType_SESSION - } else { - *r = RequireMFAType_OFF - } - case int32: - return trace.Wrap(r.setFromEnum(v)) - case int64: - return trace.Wrap(r.setFromEnum(int32(v))) - case int: - return trace.Wrap(r.setFromEnum(int32(v))) - case float64: - return trace.Wrap(r.setFromEnum(int32(v))) - case float32: - return trace.Wrap(r.setFromEnum(int32(v))) - default: - return trace.BadParameter("RequireMFAType invalid type %T", val) - } - return nil -} - -// setFromEnum sets the value from enum value as int32. -func (r *RequireMFAType) setFromEnum(val int32) error { - if _, ok := RequireMFAType_name[val]; !ok { - return trace.BadParameter("invalid required mfa mode %v", val) - } - *r = RequireMFAType(val) - return nil +func (r *RequireMFAType) decode(val any) error { + err := decodeEnum(r, val, map[any]RequireMFAType{ + "": RequireMFAType_OFF, // default to off + false: RequireMFAType_OFF, + true: RequireMFAType_SESSION, + RequireMFATypeHardwareKeyString: RequireMFAType_SESSION_AND_HARDWARE_KEY, + RequireMFATypeHardwareKeyTouchString: RequireMFAType_HARDWARE_KEY_TOUCH, + RequireMFATypeHardwareKeyPINString: RequireMFAType_HARDWARE_KEY_PIN, + RequireMFATypeHardwareKeyTouchAndPINString: RequireMFAType_HARDWARE_KEY_TOUCH_AND_PIN, + }, RequireMFAType_name) + return trace.Wrap(err, "failed to decode require mfa type") } diff --git a/api/types/enum.go b/api/types/enum.go new file mode 100644 index 000000000000..dd53c7e62a70 --- /dev/null +++ b/api/types/enum.go @@ -0,0 +1,82 @@ +/* +Copyright 2024 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "strings" + + "github.com/gravitational/trace" +) + +// decodeEnum decodes a protobuf enum from a representational value, usually a bool, +// string, or from the actual enum (int32) value. If the value is valid, it is saved +// in the given enum pointer. +func decodeEnum[T ~int32](p *T, val any, representationMap map[any]T, enumMap map[int32]string) error { + if v, ok := representationMap[val]; ok { + *p = v + return nil + } + + // try parsing as a bool value + if v, ok := val.(string); ok { + switch strings.ToLower(v) { + case "yes", "yeah", "y", "true", "1", "on": + if v, ok := representationMap[true]; ok { + *p = v + return nil + } + case "no", "nope", "n", "false", "0", "off": + if v, ok := representationMap[false]; ok { + *p = v + return nil + } + } + return trace.BadParameter("unknown enum value %v", val) + } + + // parse as enum + var enumVal T + switch v := val.(type) { + case int: + enumVal = T(v) + case int32: + enumVal = T(v) + case int64: + enumVal = T(v) + case float64: + enumVal = T(v) + case float32: + enumVal = T(v) + default: + return trace.BadParameter("unknown enum value %v", val) + } + + if err := checkEnum(enumMap, int32(enumVal)); err != nil { + return trace.BadParameter("unknown enum value %v", val) + } + + *p = enumVal + return nil +} + +// checkEnum checks if the given enum is valid. +func checkEnum(enumMap map[int32]string, val int32) error { + if _, ok := enumMap[val]; ok { + return nil + } + return trace.NotFound("enum %v not found in enum map", val) +} diff --git a/api/types/role.go b/api/types/role.go index 86eb22516aa2..5bcd5e85ebe0 100644 --- a/api/types/role.go +++ b/api/types/role.go @@ -1998,55 +1998,15 @@ func (h CreateHostUserMode) encode() (string, error) { } func (h *CreateHostUserMode) decode(val any) error { - var valS string - switch val := val.(type) { - case int32: - return trace.Wrap(h.setFromEnum(val)) - case int64: - return trace.Wrap(h.setFromEnum(int32(val))) - case int: - return trace.Wrap(h.setFromEnum(int32(val))) - case float64: - return trace.Wrap(h.setFromEnum(int32(val))) - case float32: - return trace.Wrap(h.setFromEnum(int32(val))) - case string: - valS = val - case bool: - if val { - return trace.BadParameter("create_host_user_mode cannot be true, got %v", val) - } - valS = createHostUserModeOffString - default: - return trace.BadParameter("bad value type %T, expected string or int", val) - } - - switch valS { - case "": - *h = CreateHostUserMode_HOST_USER_MODE_UNSPECIFIED - case createHostUserModeOffString: - *h = CreateHostUserMode_HOST_USER_MODE_OFF - case createHostUserModeKeepString: - *h = CreateHostUserMode_HOST_USER_MODE_KEEP - case createHostUserModeInsecureDropString, createHostUserModeDropString: - *h = CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP - default: - return trace.BadParameter("invalid host user mode %v", val) - } - return nil -} - -// setFromEnum sets the value from enum value as int32. -func (h *CreateHostUserMode) setFromEnum(val int32) error { - // Map drop to insecure-drop - if val == int32(CreateHostUserMode_HOST_USER_MODE_DROP) { - val = int32(CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP) - } - if _, ok := CreateHostUserMode_name[val]; !ok { - return trace.BadParameter("invalid host user mode %v", val) - } - *h = CreateHostUserMode(val) - return nil + err := decodeEnum(h, val, map[interface{}]CreateHostUserMode{ + "": CreateHostUserMode_HOST_USER_MODE_UNSPECIFIED, + false: CreateHostUserMode_HOST_USER_MODE_OFF, + createHostUserModeOffString: CreateHostUserMode_HOST_USER_MODE_OFF, + createHostUserModeKeepString: CreateHostUserMode_HOST_USER_MODE_KEEP, + createHostUserModeInsecureDropString: CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP, + createHostUserModeDropString: CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP, + }, CreateHostUserMode_name) + return trace.Wrap(err, "failed to decode host user mode") } // UnmarshalYAML supports parsing CreateHostUserMode from string. @@ -2114,28 +2074,13 @@ func (h CreateDatabaseUserMode) encode() (string, error) { } func (h *CreateDatabaseUserMode) decode(val any) error { - var str string - switch val := val.(type) { - case string: - str = val - default: - return trace.BadParameter("bad value type %T, expected string", val) - } - - switch str { - case "": - *h = CreateDatabaseUserMode_DB_USER_MODE_UNSPECIFIED - case createDatabaseUserModeOffString: - *h = CreateDatabaseUserMode_DB_USER_MODE_OFF - case createDatabaseUserModeKeepString: - *h = CreateDatabaseUserMode_DB_USER_MODE_KEEP - case createDatabaseUserModeBestEffortDropString: - *h = CreateDatabaseUserMode_DB_USER_MODE_BEST_EFFORT_DROP - default: - return trace.BadParameter("invalid database user mode %v", val) - } - - return nil + err := decodeEnum(h, val, map[interface{}]CreateDatabaseUserMode{ + "": CreateDatabaseUserMode_DB_USER_MODE_UNSPECIFIED, + createDatabaseUserModeOffString: CreateDatabaseUserMode_DB_USER_MODE_OFF, + createDatabaseUserModeKeepString: CreateDatabaseUserMode_DB_USER_MODE_KEEP, + createDatabaseUserModeBestEffortDropString: CreateDatabaseUserMode_DB_USER_MODE_BEST_EFFORT_DROP, + }, CreateDatabaseUserMode_name) + return trace.Wrap(err, "failed to decode require mfa type") } // UnmarshalYAML supports parsing CreateDatabaseUserMode from string.