Skip to content

Commit

Permalink
Refactor enum decoding for reusability and readability.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Oct 4, 2024
1 parent e77bf35 commit c265007
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 126 deletions.
66 changes: 11 additions & 55 deletions api/types/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -1118,59 +1118,15 @@ func (r *RequireMFAType) encode() (interface{}, error) {
// decode RequireMFAType from a string or boolean. This is necessary for
// backwards compatibility with the json/yaml tag "require_session_mfa",
// which used to be a boolean.
func (r *RequireMFAType) decode(val interface{}) error {
switch v := val.(type) {
case string:
switch v {
case RequireMFATypeHardwareKeyString:
*r = RequireMFAType_SESSION_AND_HARDWARE_KEY
case RequireMFATypeHardwareKeyTouchString:
*r = RequireMFAType_HARDWARE_KEY_TOUCH
case RequireMFATypeHardwareKeyPINString:
*r = RequireMFAType_HARDWARE_KEY_PIN
case RequireMFATypeHardwareKeyTouchAndPINString:
*r = RequireMFAType_HARDWARE_KEY_TOUCH_AND_PIN
case "":
// default to off
*r = RequireMFAType_OFF
default:
// try parsing as a boolean
switch strings.ToLower(v) {
case "yes", "yeah", "y", "true", "1", "on":
*r = RequireMFAType_SESSION
case "no", "nope", "n", "false", "0", "off":
*r = RequireMFAType_OFF
default:
return trace.BadParameter("RequireMFAType invalid value %v", val)
}
}
case bool:
if v {
*r = RequireMFAType_SESSION
} else {
*r = RequireMFAType_OFF
}
case int32:
return trace.Wrap(r.setFromEnum(v))
case int64:
return trace.Wrap(r.setFromEnum(int32(v)))
case int:
return trace.Wrap(r.setFromEnum(int32(v)))
case float64:
return trace.Wrap(r.setFromEnum(int32(v)))
case float32:
return trace.Wrap(r.setFromEnum(int32(v)))
default:
return trace.BadParameter("RequireMFAType invalid type %T", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (r *RequireMFAType) setFromEnum(val int32) error {
if _, ok := RequireMFAType_name[val]; !ok {
return trace.BadParameter("invalid required mfa mode %v", val)
}
*r = RequireMFAType(val)
return nil
func (r *RequireMFAType) decode(val any) error {
err := decodeEnum(r, val, map[any]RequireMFAType{
"": RequireMFAType_OFF, // default to off
false: RequireMFAType_OFF,
true: RequireMFAType_SESSION,
RequireMFATypeHardwareKeyString: RequireMFAType_SESSION_AND_HARDWARE_KEY,
RequireMFATypeHardwareKeyTouchString: RequireMFAType_HARDWARE_KEY_TOUCH,
RequireMFATypeHardwareKeyPINString: RequireMFAType_HARDWARE_KEY_PIN,
RequireMFATypeHardwareKeyTouchAndPINString: RequireMFAType_HARDWARE_KEY_TOUCH_AND_PIN,
}, RequireMFAType_name)
return trace.Wrap(err, "failed to decode require mfa type")
}
82 changes: 82 additions & 0 deletions api/types/enum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
Copyright 2024 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package types

import (
"strings"

"github.com/gravitational/trace"
)

// decodeEnum decodes a protobuf enum from a representational value, usually a bool,
// string, or from the actual enum (int32) value. If the value is valid, it is saved
// in the given enum pointer.
func decodeEnum[T ~int32](p *T, val any, representationMap map[any]T, enumMap map[int32]string) error {
if v, ok := representationMap[val]; ok {
*p = v
return nil
}

// try parsing as a bool value
if v, ok := val.(string); ok {
switch strings.ToLower(v) {
case "yes", "yeah", "y", "true", "1", "on":
if v, ok := representationMap[true]; ok {
*p = v
return nil
}
case "no", "nope", "n", "false", "0", "off":
if v, ok := representationMap[false]; ok {
*p = v
return nil
}
}
return trace.BadParameter("unknown enum value %v", val)
}

// parse as enum
var enumVal T
switch v := val.(type) {
case int:
enumVal = T(v)
case int32:
enumVal = T(v)
case int64:
enumVal = T(v)
case float64:
enumVal = T(v)
case float32:
enumVal = T(v)
default:
return trace.BadParameter("unknown enum value %v", val)
}

if err := checkEnum(enumMap, int32(enumVal)); err != nil {
return trace.BadParameter("unknown enum value %v", val)
}

*p = enumVal
return nil
}

// checkEnum checks if the given enum is valid.
func checkEnum(enumMap map[int32]string, val int32) error {
if _, ok := enumMap[val]; ok {
return nil
}
return trace.NotFound("enum %v not found in enum map", val)
}
87 changes: 16 additions & 71 deletions api/types/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -1998,55 +1998,15 @@ func (h CreateHostUserMode) encode() (string, error) {
}

func (h *CreateHostUserMode) decode(val any) error {
var valS string
switch val := val.(type) {
case int32:
return trace.Wrap(h.setFromEnum(val))
case int64:
return trace.Wrap(h.setFromEnum(int32(val)))
case int:
return trace.Wrap(h.setFromEnum(int32(val)))
case float64:
return trace.Wrap(h.setFromEnum(int32(val)))
case float32:
return trace.Wrap(h.setFromEnum(int32(val)))
case string:
valS = val
case bool:
if val {
return trace.BadParameter("create_host_user_mode cannot be true, got %v", val)
}
valS = createHostUserModeOffString
default:
return trace.BadParameter("bad value type %T, expected string or int", val)
}

switch valS {
case "":
*h = CreateHostUserMode_HOST_USER_MODE_UNSPECIFIED
case createHostUserModeOffString:
*h = CreateHostUserMode_HOST_USER_MODE_OFF
case createHostUserModeKeepString:
*h = CreateHostUserMode_HOST_USER_MODE_KEEP
case createHostUserModeInsecureDropString, createHostUserModeDropString:
*h = CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP
default:
return trace.BadParameter("invalid host user mode %v", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (h *CreateHostUserMode) setFromEnum(val int32) error {
// Map drop to insecure-drop
if val == int32(CreateHostUserMode_HOST_USER_MODE_DROP) {
val = int32(CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP)
}
if _, ok := CreateHostUserMode_name[val]; !ok {
return trace.BadParameter("invalid host user mode %v", val)
}
*h = CreateHostUserMode(val)
return nil
err := decodeEnum(h, val, map[interface{}]CreateHostUserMode{
"": CreateHostUserMode_HOST_USER_MODE_UNSPECIFIED,
false: CreateHostUserMode_HOST_USER_MODE_OFF,
createHostUserModeOffString: CreateHostUserMode_HOST_USER_MODE_OFF,
createHostUserModeKeepString: CreateHostUserMode_HOST_USER_MODE_KEEP,
createHostUserModeInsecureDropString: CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP,
createHostUserModeDropString: CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP,
}, CreateHostUserMode_name)
return trace.Wrap(err, "failed to decode host user mode")
}

// UnmarshalYAML supports parsing CreateHostUserMode from string.
Expand Down Expand Up @@ -2114,28 +2074,13 @@ func (h CreateDatabaseUserMode) encode() (string, error) {
}

func (h *CreateDatabaseUserMode) decode(val any) error {
var str string
switch val := val.(type) {
case string:
str = val
default:
return trace.BadParameter("bad value type %T, expected string", val)
}

switch str {
case "":
*h = CreateDatabaseUserMode_DB_USER_MODE_UNSPECIFIED
case createDatabaseUserModeOffString:
*h = CreateDatabaseUserMode_DB_USER_MODE_OFF
case createDatabaseUserModeKeepString:
*h = CreateDatabaseUserMode_DB_USER_MODE_KEEP
case createDatabaseUserModeBestEffortDropString:
*h = CreateDatabaseUserMode_DB_USER_MODE_BEST_EFFORT_DROP
default:
return trace.BadParameter("invalid database user mode %v", val)
}

return nil
err := decodeEnum(h, val, map[interface{}]CreateDatabaseUserMode{
"": CreateDatabaseUserMode_DB_USER_MODE_UNSPECIFIED,
createDatabaseUserModeOffString: CreateDatabaseUserMode_DB_USER_MODE_OFF,
createDatabaseUserModeKeepString: CreateDatabaseUserMode_DB_USER_MODE_KEEP,
createDatabaseUserModeBestEffortDropString: CreateDatabaseUserMode_DB_USER_MODE_BEST_EFFORT_DROP,
}, CreateDatabaseUserMode_name)
return trace.Wrap(err, "failed to decode require mfa type")
}

// UnmarshalYAML supports parsing CreateDatabaseUserMode from string.
Expand Down

0 comments on commit c265007

Please sign in to comment.