From d5ae1afe40d3cb1b52bd4832731d94edbd0b94ef Mon Sep 17 00:00:00 2001
From: huweiwen
Date: Thu, 7 Mar 2024 16:23:48 +0800
Subject: [PATCH] unify node/controller service enabling

Every driver now supports enabling or disabling its node and controller
services through a unified interface:

- deprecate the SERVICE_TYPE env
- deprecate the -run-as-controller flag
- introduce -run-controller-service and -run-node-service, allowing the
  controller and node services to be enabled or disabled independently

---
 main.go | 61 +++++++++------------------
 pkg/dbfs/dbfs.go | 11 +++--
 pkg/disk/controllerserver.go | 3 +-
 pkg/disk/disk.go | 34 +++++++--------
 pkg/disk/nodeserver.go | 2 +-
 pkg/ens/ens.go | 14 ++----
 pkg/metric/collector.go | 38 ++++++++---------
 pkg/metric/metric.go | 4 +-
 pkg/nas/nas.go | 8 ++--
 pkg/oss/oss.go | 8 ++--
 pkg/pov/pov.go | 30 ++++++-------
 pkg/utils/util.go | 50 +++++++++++++++++++++-
 pkg/utils/util_test.go | 82 ++++++++++++++++++++++++++++++++++++
 13 files changed, 224 insertions(+), 121 deletions(-)

diff --git a/main.go b/main.go index 1b13b3299..ede10ef3b 100644 --- a/main.go +++ b/main.go @@ -80,34 +80,22 @@ const ( ) var ( - endpoint = flag.String("endpoint", "unix://tmp/csi.sock", "CSI endpoint") - nodeID = flag.String("nodeid", "", "node id") - runAsController = flag.Bool("run-as-controller", false, "Only run as controller service") - driver = flag.String("driver", TypePluginDISK, "CSI Driver") + endpoint = flag.String("endpoint", "unix://tmp/csi.sock", "CSI endpoint") + nodeID = flag.String("nodeid", "", "node id") + runAsController = flag.Bool("run-as-controller", false, "Only run as controller service (deprecated)") + runControllerService = flag.Bool("run-controller-service", true, "activate CSI controller service") + runNodeService = flag.Bool("run-node-service", true, "activate CSI node service") + driver = flag.String("driver", TypePluginDISK, "CSI Driver") // Deprecated: rootDir is instead by KUBELET_ROOT_DIR env. rootDir = flag.String("rootdir", "/var/lib/kubelet/csi-plugins", "Kubernetes root directory") ) -type globalMetricConfig struct { - enableMetric bool - serviceType string -} - -// Nas CSI Plugin func main() { flag.Var(features.FunctionalMutableFeatureGate, "feature-gates", "A set of key=value pairs that describe feature gates for alpha/experimental features. "+ "Options are:\n"+strings.Join(features.FunctionalMutableFeatureGate.KnownFeatures(), "\n")) flag.Parse() - serviceType := os.Getenv(utils.ServiceType) - - if len(serviceType) == 0 || serviceType == "" { - serviceType = utils.PluginService - } + serviceType := utils.GetServiceType(*runAsController, *runControllerService, *runNodeService) - // When serviceType is neither plugin nor provisioner, the program will exits.
- if serviceType != utils.PluginService && serviceType != utils.ProvisionerService { - log.Fatalf("Service type is unknown:%s", serviceType) - } // enable pprof analyse pprofPort := os.Getenv("PPROF_PORT") if pprofPort != "" { @@ -183,13 +171,13 @@ func main() { case TypePluginOSS: go func(endPoint string) { defer wg.Done() - driver := oss.NewDriver(*nodeID, endPoint, meta, *runAsController) + driver := oss.NewDriver(*nodeID, endPoint, meta, serviceType) driver.Run() }(endPointName) case TypePluginDISK: go func(endPoint string) { defer wg.Done() - driver := disk.NewDriver(meta, endPoint, *runAsController) + driver := disk.NewDriver(meta, endPoint, serviceType) driver.Run() }(endPointName) @@ -198,13 +186,13 @@ func main() { case TypePluginDBFS: go func(endPoint string) { defer wg.Done() - driver := dbfs.NewDriver(*nodeID, endPoint) + driver := dbfs.NewDriver(*nodeID, endPoint, serviceType) driver.Run() }(endPointName) case TypePluginENS: go func(endpoint string) { defer wg.Done() - driver := ens.NewDriver(*nodeID, endpoint) + driver := ens.NewDriver(*nodeID, endpoint, serviceType) driver.Run() }(endPointName) case ExtenderAgent: @@ -216,7 +204,7 @@ func main() { case TypePluginPOV: go func(endPoint string) { defer wg.Done() - driver := pov.NewDriver(*nodeID, endPoint, *runAsController) + driver := pov.NewDriver(*nodeID, endPoint, serviceType) driver.Run() }(endPointName) default: @@ -225,34 +213,25 @@ func main() { } servicePort := os.Getenv("SERVICE_PORT") - if len(servicePort) == 0 || servicePort == "" { - switch serviceType { - case utils.PluginService: - servicePort = PluginServicePort - case utils.ProvisionerService: + if servicePort == "" { + servicePort = PluginServicePort + if serviceType&utils.Controller != 0 { servicePort = ProvisionerServicePort - default: } } - metricConfig := &globalMetricConfig{ - true, - "plugin", - } - - enableMetric := os.Getenv("ENABLE_METRIC") version.SetPrometheusVersion() - if enableMetric == "false" { - metricConfig.enableMetric = false + enableMetric := true + if os.Getenv("ENABLE_METRIC") == "false" { + enableMetric = false } - metricConfig.serviceType = serviceType log.Info("CSI is running status.") csiMux := http.NewServeMux() csiMux.HandleFunc("/healthz", healthHandler) log.Infof("Metric listening on address: /healthz") - if metricConfig.enableMetric { - metricHandler := metric.NewMetricHandler(metricConfig.serviceType, driverNames) + if enableMetric && serviceType&utils.Node != 0 { + metricHandler := metric.NewMetricHandler(driverNames) csiMux.Handle("/metrics", metricHandler) log.Infof("Metric listening on address: /metrics") } diff --git a/pkg/dbfs/dbfs.go b/pkg/dbfs/dbfs.go index 743e67522..8f5913e42 100644 --- a/pkg/dbfs/dbfs.go +++ b/pkg/dbfs/dbfs.go @@ -55,7 +55,7 @@ type DBFS struct { } // NewDriver create the identity/node/controller server and dbfs driver -func NewDriver(nodeID, endpoint string) *DBFS { +func NewDriver(nodeID, endpoint string, serviceType utils.ServiceType) *DBFS { log.Infof("Driver: %v version: %v", driverName, version.VERSION) d := &DBFS{} @@ -81,7 +81,12 @@ func NewDriver(nodeID, endpoint string) *DBFS { if region == "" { region, _ = utils.GetMetaData(RegionTag) } - d.controllerServer = NewControllerServer(d.driver, c, region) + if serviceType&utils.Controller != 0 { + d.controllerServer = NewControllerServer(d.driver, c, region) + } + if serviceType&utils.Node != 0 { + d.nodeServer = newNodeServer(d) + } GlobalConfigVar.DbfsClient = c // Global Configs Set @@ -91,7 +96,7 @@ func NewDriver(nodeID, endpoint string) 
*DBFS { // Run start a new NodeServer func (d *DBFS) Run() { - common.RunCSIServer(d.endpoint, NewIdentityServer(d.driver), d.controllerServer, newNodeServer(d)) + common.RunCSIServer(d.endpoint, NewIdentityServer(d.driver), d.controllerServer, d.nodeServer) } // GlobalConfigSet set global config diff --git a/pkg/disk/controllerserver.go b/pkg/disk/controllerserver.go index 61ee0878f..8aba794b9 100644 --- a/pkg/disk/controllerserver.go +++ b/pkg/disk/controllerserver.go @@ -109,8 +109,7 @@ func NewControllerServer(d *csicommon.CSIDriver, client *crd.Clientset) csi.Cont SnapshotRequestInterval = interval } - serviceType := os.Getenv(utils.ServiceType) - if serviceType == utils.ProvisionerService && installCRD { + if installCRD { checkInstallCRD(client) checkInstallDefaultVolumeSnapshotClass(GlobalConfigVar.SnapClient) } diff --git a/pkg/disk/disk.go b/pkg/disk/disk.go index 6e1643634..d6a0086ec 100644 --- a/pkg/disk/disk.go +++ b/pkg/disk/disk.go @@ -74,7 +74,6 @@ type GlobalConfig struct { ClientSet *kubernetes.Clientset ClusterID string DiskPartitionEnable bool - ControllerService bool BdfHealthCheck bool DiskMultiTenantEnable bool CheckBDFHotPlugin bool @@ -99,7 +98,7 @@ func initDriver() { } // NewDriver create the identity/node/controller server and disk driver -func NewDriver(m metadata.MetadataProvider, endpoint string, runAsController bool) *DISK { +func NewDriver(m metadata.MetadataProvider, endpoint string, serviceType utils.ServiceType) *DISK { initDriver() tmpdisk := &DISK{} tmpdisk.endpoint = endpoint @@ -107,6 +106,12 @@ func NewDriver(m metadata.MetadataProvider, endpoint string, runAsController boo // Config Global vars cfg := GlobalConfigSet(m) + if serviceType&utils.Node != 0 { + GlobalConfigVar.NodeID = metadata.MustGet(m, metadata.InstanceID) + } else { + GlobalConfigVar.NodeID = "not-retrieved" // make csi-common happy + } + csiDriver := csicommon.NewCSIDriver(driverName, version.VERSION, GlobalConfigVar.NodeID) tmpdisk.driver = csiDriver tmpdisk.driver.AddControllerServiceCapabilities([]csi.ControllerServiceCapability_RPC_Type{ @@ -135,9 +140,10 @@ func NewDriver(m metadata.MetadataProvider, endpoint string, runAsController boo // Create GRPC servers tmpdisk.idServer = NewIdentityServer(tmpdisk.driver) - tmpdisk.controllerServer = NewControllerServer(tmpdisk.driver, apiExtentionClient) - - if !runAsController { + if serviceType&utils.Controller != 0 { + tmpdisk.controllerServer = NewControllerServer(tmpdisk.driver, apiExtentionClient) + } + if serviceType&utils.Node != 0 { tmpdisk.nodeServer = NewNodeServer(tmpdisk.driver, m) } @@ -220,19 +226,9 @@ func GlobalConfigSet(m metadata.MetadataProvider) *restclient.Config { } clustID := os.Getenv("CLUSTER_ID") - controllerServerType := false - nodeID := "" - if os.Getenv(utils.ServiceType) == utils.ProvisionerService { - controllerServerType = true - nodeID = "controller" // make csi-common happy - } else { - nodeID = metadata.MustGet(m, metadata.InstanceID) - } - // Global Config Set GlobalConfigVar = GlobalConfig{ Region: metadata.MustGet(m, metadata.RegionID), - NodeID: nodeID, ADControllerEnable: features.FunctionalMutableFeatureGate.Enabled(features.DiskADController) || csiCfg.GetBool("disk-adcontroller-enable", "DISK_AD_CONTROLLER", false), DiskTagEnable: csiCfg.GetBool("disk-tag-by-plugin", "DISK_TAGED_BY_PLUGIN", false), @@ -246,7 +242,6 @@ func GlobalConfigSet(m metadata.MetadataProvider) *restclient.Config { SnapClient: snapClient, ClusterID: clustID, DiskPartitionEnable: csiCfg.GetBool("disk-partition-enable", 
"DISK_PARTITION_ENABLE", true), - ControllerService: controllerServerType, BdfHealthCheck: csiCfg.GetBool("bdf-health-check", "BDF_HEALTH_CHECK", true), DiskMultiTenantEnable: csiCfg.GetBool("disk-multi-tenant-enable", "DISK_MULTI_TENANT_ENABLE", false), NodeMultiZoneEnable: csiCfg.GetBool("node-multi-zone-enable", "NODE_MULTI_ZONE_ENABLE", false), @@ -273,11 +268,12 @@ func GlobalConfigSet(m metadata.MetadataProvider) *restclient.Config { GlobalConfigVar.ClusterID, ) - if controllerServerType && !csiCfg.GetBool("disk-serial-attach", "DISK_SERIAL_ATTACH", false) { + // if ADController is not enabled, we need SERIAL_ATTACH to recognize old disk + if !GlobalConfigVar.ADControllerEnable || csiCfg.GetBool("disk-serial-attach", "DISK_SERIAL_ATTACH", false) { + GlobalConfigVar.AttachDetachSlots = NewSerialAttachDetachSlots() + } else { log.Infof("Disk parallel attach/detach enabled, please set DISK_SERIAL_ATTACH if you see a lot of InvalidOperation.Conflict error.") GlobalConfigVar.AttachDetachSlots = NewParallelAttachDetachSlots() - } else { - GlobalConfigVar.AttachDetachSlots = NewSerialAttachDetachSlots() } return cfg diff --git a/pkg/disk/nodeserver.go b/pkg/disk/nodeserver.go index 6c9c4ce3d..f8d3ea655 100644 --- a/pkg/disk/nodeserver.go +++ b/pkg/disk/nodeserver.go @@ -163,7 +163,7 @@ func NewNodeServer(d *csicommon.CSIDriver, m metadata.MetadataProvider) csi.Node go checkVfhpOnlineReconcile() } - if !GlobalConfigVar.ControllerService && IsVFNode() && GlobalConfigVar.BdfHealthCheck { + if IsVFNode() && GlobalConfigVar.BdfHealthCheck { go BdfHealthCheck() } diff --git a/pkg/ens/ens.go b/pkg/ens/ens.go index 26561086d..dc58c55c5 100644 --- a/pkg/ens/ens.go +++ b/pkg/ens/ens.go @@ -59,7 +59,7 @@ type ENS struct { func initDriver() {} -func NewDriver(nodeID, endpoint string) *ENS { +func NewDriver(nodeID, endpoint string, serviceType utils.ServiceType) *ENS { initDriver() tmpENS := &ENS{} @@ -79,9 +79,10 @@ func NewDriver(nodeID, endpoint string) *ENS { tmpENS.driver.AddVolumeCapabilityAccessModes([]csi.VolumeCapability_AccessMode_Mode{csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER}) tmpENS.idServer = NewIdentityServer(tmpENS.driver) - if GlobalConfigVar.ControllerService { + if serviceType&utils.Controller != 0 { tmpENS.controllerServer = NewControllerServer(tmpENS.driver) - } else { + } + if serviceType&utils.Node != 0 { tmpENS.nodeServer = NewNodeServer(tmpENS.driver) } return tmpENS @@ -155,11 +156,6 @@ func NewGlobalConfig() { detachBeforeAttach = true } - controllerServerType := false - if os.Getenv(utils.ServiceType) == utils.ProvisionerService { - controllerServerType = true - } - GlobalConfigVar = GlobalConfig{ KClient: kubeClient, InstanceID: instanceID, @@ -168,7 +164,6 @@ func NewGlobalConfig() { RegionID: regionID, EnableAttachDetachController: attachDetachController, DetachBeforeAttach: detachBeforeAttach, - ControllerService: controllerServerType, } } @@ -177,7 +172,6 @@ type GlobalConfig struct { InstanceID string ClusterID string DetachBeforeAttach bool - ControllerService bool EnableDiskPartition string EnableAttachDetachController string diff --git a/pkg/metric/collector.go b/pkg/metric/collector.go index b46d993c6..0406572f3 100644 --- a/pkg/metric/collector.go +++ b/pkg/metric/collector.go @@ -45,34 +45,32 @@ type CSICollector struct { } // newCSICollector method returns the CSICollector object -func newCSICollector(metricType string, driverNames []string) error { +func newCSICollector(driverNames []string) error { if csiCollectorInstance != nil { return nil } 
collectors := make(map[string]Collector) - if metricType == pluginService { - enabledDrivers := map[string]struct{}{} - for _, d := range driverNames { - enabledDrivers[d] = struct{}{} - } - for _, reg := range registry { - enabled := len(reg.RelatedDrivers) == 0 - for _, d := range reg.RelatedDrivers { - if _, ok := enabledDrivers[d]; ok { - enabled = true - break - } + enabledDrivers := map[string]struct{}{} + for _, d := range driverNames { + enabledDrivers[d] = struct{}{} + } + for _, reg := range registry { + enabled := len(reg.RelatedDrivers) == 0 + for _, d := range reg.RelatedDrivers { + if _, ok := enabledDrivers[d]; ok { + enabled = true + break } - if enabled { - collector, err := reg.Factory() - if err != nil { - return err - } - collectors[reg.Name] = collector + } + if enabled { + collector, err := reg.Factory() + if err != nil { + return err } + collectors[reg.Name] = collector } - } + csiCollectorInstance = &CSICollector{Collectors: collectors} return nil diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index 48adc00b7..301eac217 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -39,9 +39,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // NewMetricHandler method returns a promHttp object -func NewMetricHandler(serviceType string, driverNames []string) *Handler { +func NewMetricHandler(driverNames []string) *Handler { //csi collector singleton - err := newCSICollector(serviceType, driverNames) + err := newCSICollector(driverNames) if err != nil { logrus.Errorf("Couldn't create collector: %s", err) } diff --git a/pkg/nas/nas.go b/pkg/nas/nas.go index b82fec059..cab8e95b8 100644 --- a/pkg/nas/nas.go +++ b/pkg/nas/nas.go @@ -37,14 +37,14 @@ type NAS struct { nodeServer *nodeServer } -func NewDriver(meta *metadata.Metadata, endpoint, serviceType string) *NAS { +func NewDriver(meta *metadata.Metadata, endpoint string, serviceType utils.ServiceType) *NAS { log.Infof("Driver: %v version: %v", driverName, version.VERSION) var d NAS d.endpoint = endpoint d.identityServer = newIdentityServer(driverName, version.VERSION) - if serviceType == utils.ProvisionerService { + if serviceType&utils.Controller != 0 { config, err := internal.GetControllerConfig(meta) if err != nil { log.Fatalf("Get nas controller config: %v", err) @@ -54,14 +54,14 @@ func NewDriver(meta *metadata.Metadata, endpoint, serviceType string) *NAS { log.Fatalf("Failed to init nas controller server: %v", err) } d.controllerServer = cs - } else { + } + if serviceType&utils.Node != 0 { config, err := internal.GetNodeConfig() if err != nil { log.Fatalf("Get nas node config: %v", err) } d.nodeServer = newNodeServer(config) } - return &d } diff --git a/pkg/oss/oss.go b/pkg/oss/oss.go index b8a260ddb..048b92c73 100644 --- a/pkg/oss/oss.go +++ b/pkg/oss/oss.go @@ -50,7 +50,7 @@ type OSS struct { } // NewDriver init oss type of csi driver -func NewDriver(nodeID, endpoint string, m metadata.MetadataProvider, runAsController bool) *OSS { +func NewDriver(nodeID, endpoint string, m metadata.MetadataProvider, serviceType utils.ServiceType) *OSS { log.Infof("Driver: %v version: %v", driverName, version.VERSION) d := &OSS{} @@ -71,8 +71,10 @@ func NewDriver(nodeID, endpoint string, m metadata.MetadataProvider, runAsContro d.driver = csiDriver - d.controllerServer = newControllerServer(d.driver) - if !runAsController { + if serviceType&utils.Controller != 0 { + d.controllerServer = newControllerServer(d.driver) + } + if serviceType&utils.Node != 0 { d.nodeServer = 
newNodeServer(d.driver, m) } return d diff --git a/pkg/pov/pov.go b/pkg/pov/pov.go index 2748fc9f2..bad096830 100644 --- a/pkg/pov/pov.go +++ b/pkg/pov/pov.go @@ -6,6 +6,7 @@ import ( "github.com/kubernetes-sigs/alibaba-cloud-csi-driver/pkg/common" "github.com/kubernetes-sigs/alibaba-cloud-csi-driver/pkg/options" + "github.com/kubernetes-sigs/alibaba-cloud-csi-driver/pkg/utils" log "github.com/sirupsen/logrus" "k8s.io/client-go/kubernetes" "k8s.io/client-go/tools/clientcmd" @@ -27,15 +28,16 @@ type PoV struct { func initDriver() {} -func NewDriver(nodeID, endpoint string, runAsController bool) *PoV { +func NewDriver(nodeID, endpoint string, serviceType utils.ServiceType) *PoV { initDriver() poV := &PoV{} poV.endpoint = endpoint - newGlobalConfig(runAsController) + newGlobalConfig() - if runAsController { + if serviceType&utils.Controller != 0 { poV.controllerService = newControllerService() - } else { + } + if serviceType&utils.Node != 0 { poV.nodeService = newNodeService() } @@ -48,7 +50,7 @@ func (p *PoV) Run() { common.RunCSIServer(p.endpoint, p, &p.controllerService, &p.nodeService) } -func newGlobalConfig(runAsController bool) { +func newGlobalConfig() { cfg, err := clientcmd.BuildConfigFromFlags(options.MasterURL, options.Kubeconfig) if err != nil { log.Fatalf("newGlobalConfig: build kubeconfig failed: %v", err) @@ -76,18 +78,16 @@ func newGlobalConfig(runAsController bool) { } GlobalConfigVar = GlobalConfig{ - controllerService: runAsController, - client: kubeClient, - regionID: doc.RegionID, - instanceID: doc.InstanceID, - zoneID: doc.ZoneID, + client: kubeClient, + regionID: doc.RegionID, + instanceID: doc.InstanceID, + zoneID: doc.ZoneID, } } type GlobalConfig struct { - regionID string - instanceID string - zoneID string - controllerService bool - client kubernetes.Interface + regionID string + instanceID string + zoneID string + client kubernetes.Interface } diff --git a/pkg/utils/util.go b/pkg/utils/util.go index 165797e06..fd72f9454 100644 --- a/pkg/utils/util.go +++ b/pkg/utils/util.go @@ -70,7 +70,7 @@ const ( // RunvRunTimeTag tag RunvRunTimeTag = "runv" // ServiceType tag - ServiceType = "SERVICE_TYPE" + ServiceTypeEnv = "SERVICE_TYPE" // PluginService represents the csi-plugin type. PluginService = "plugin" // ProvisionerService represents the csi-provisioner type. 
@@ -96,6 +96,54 @@ const ( GiB = 1024 * 1024 * 1024 ) +type ServiceType int + +const ( + Controller ServiceType = 1 << iota + Node +) + +func GetServiceType(runAsController, runControllerService, runNodeService bool) ServiceType { + serviceType := ServiceType(0) + if runAsController { + log.Warn("-run-as-controller is deprecated, use -run-node-service=false instead") + serviceType = Controller + } + if st := os.Getenv(ServiceTypeEnv); st != "" { + log.Warnf("%s env support is deprecated, use -run-controller-service and -run-node-service instead", ServiceTypeEnv) + switch st { + case PluginService: + if runAsController { + log.Fatalf("%s env is set to %s, but -run-as-controller is also set", ServiceTypeEnv, st) + } + serviceType = Node + case ProvisionerService: + serviceType = Controller + default: + log.Fatalf("invalid %s env value: %s", ServiceTypeEnv, st) + } + } + if serviceType == 0 { + // nothing deprecated was set, use new flags + if runControllerService { + serviceType |= Controller + } + if runNodeService { + serviceType |= Node + } + } + if serviceType == 0 { + log.Warn("no service type activated, this configuration may not be useful") + } + if serviceType&Controller != 0 { + log.Infof("activate CSI controller service") + } + if serviceType&Node != 0 { + log.Infof("activate CSI node service") + } + return serviceType +} + // RoleAuth define STS Token Response type RoleAuth struct { AccessKeyID string diff --git a/pkg/utils/util_test.go b/pkg/utils/util_test.go index c2a33dc54..79309344f 100644 --- a/pkg/utils/util_test.go +++ b/pkg/utils/util_test.go @@ -17,8 +17,10 @@ limitations under the License. package utils import ( + "fmt" "testing" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" k8smount "k8s.io/mount-utils" ) @@ -126,6 +128,86 @@ func TestCmdValid(t *testing.T) { assert.Nil(t, CheckCmdArgs(cmd, strings.Split(cmd, " ")[1:]...))*/ } +func TestGetServiceType(t *testing.T) { + logrus.StandardLogger().ExitFunc = func(code int) { + // panic on fatal for testing + panic(fmt.Sprintf("exit on fatal log code %d", code)) + } + tests := []struct { + name string + runAsController bool + runControllerService bool + runNodeService bool + serviceTypeEnv string + want ServiceType + fatal bool + }{ + { + name: "default", + runControllerService: true, + runNodeService: true, + want: Controller | Node, + }, + { + name: "Run as controller", + runAsController: true, + want: Controller, + }, + { + name: "env provisioner", + serviceTypeEnv: ProvisionerService, + want: Controller, + }, + { + name: "env plugin", + serviceTypeEnv: PluginService, + want: Node, + }, + { + name: "Run controller", + runControllerService: true, + runNodeService: false, + want: Controller, + }, + { + name: "Run node", + runControllerService: false, + runNodeService: true, + want: Node, + }, + { + name: "nothing", + runControllerService: false, + runNodeService: false, + want: 0, + }, + { + name: "invalid env", + serviceTypeEnv: "invalid", + fatal: true, + }, + { + name: "conflict env and run-as-controller", + runAsController: true, + serviceTypeEnv: PluginService, + fatal: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("SERVICE_TYPE", tt.serviceTypeEnv) + if tt.fatal { + assert.Panics(t, func() { + GetServiceType(tt.runAsController, tt.runControllerService, tt.runNodeService) + }) + return + } + got := GetServiceType(tt.runAsController, tt.runControllerService, tt.runNodeService) + assert.Equal(t, tt.want, got) + }) + } +} + func TestIsDirTmpfs(t
*testing.T) { mounter := k8smount.NewFakeMounter([]k8smount.MountPoint{ {
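
Below is a minimal, self-contained sketch (not part of the patch) of how a driver constructor consumes the ServiceType bitmask returned by utils.GetServiceType. fakeDriver and newFakeDriver are illustrative names only, standing in for real constructors such as disk.NewDriver or nas.NewDriver.

package main

import "fmt"

// ServiceType mirrors the bitmask added in pkg/utils/util.go.
type ServiceType int

const (
	Controller ServiceType = 1 << iota
	Node
)

// fakeDriver stands in for a real driver struct such as disk.DISK or nas.NAS.
type fakeDriver struct {
	controllerEnabled bool
	nodeEnabled       bool
}

// newFakeDriver follows the pattern each NewDriver adopts in this patch:
// build the controller server only when the Controller bit is set, and the
// node server only when the Node bit is set.
func newFakeDriver(serviceType ServiceType) *fakeDriver {
	d := &fakeDriver{}
	if serviceType&Controller != 0 {
		d.controllerEnabled = true // a real driver constructs its controller server here
	}
	if serviceType&Node != 0 {
		d.nodeEnabled = true // a real driver constructs its node server here
	}
	return d
}

func main() {
	// Default flags: -run-controller-service=true -run-node-service=true.
	fmt.Printf("%+v\n", newFakeDriver(Controller|Node))
	// A node DaemonSet would pass -run-controller-service=false.
	fmt.Printf("%+v\n", newFakeDriver(Node))
	// A controller Deployment would pass -run-node-service=false.
	fmt.Printf("%+v\n", newFakeDriver(Controller))
}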