diff --git a/cmd/kured/main.go b/cmd/kured/main.go index 96e229f7d..96b754130 100644 --- a/cmd/kured/main.go +++ b/cmd/kured/main.go @@ -244,26 +244,14 @@ func main() { log.Infof("Reboot schedule: %v", window) log.Infof("Reboot method: %s", rebootMethod) - var rebooter reboot.Rebooter - switch { - case rebootMethod == "command": - log.Infof("Reboot command: %s", rebootCommand) - rebooter = reboot.NewCommandRebooter(rebootCommand) - case rebootMethod == "signal": - log.Infof("Reboot signal: %v", rebootSignal) - rebooter = reboot.NewSignalRebooter(rebootSignal) - default: - log.Fatalf("Invalid reboot-method configured: %s", rebootMethod) - } - - var checker checkers.Checker - // An override of rebootSentinelCommand means a privileged command - if rebootSentinelCommand != "" { - log.Infof("Sentinel checker is (privileged) user provided command: %s", rebootSentinelCommand) - checker = checkers.NewCommandChecker(rebootSentinelCommand) - } else { - log.Infof("Sentinel checker is (unprivileged) testing for the presence of: %s", rebootSentinelFile) - checker = checkers.NewFileRebootChecker(rebootSentinelFile) + rebooter, err := reboot.NewRebooter(rebootMethod, rebootCommand, rebootSignal) + if err != nil { + log.Fatalf("Failed to build rebooter: %v", err) + } + + rebootChecker, err := checkers.NewRebootChecker(rebootSentinelCommand, rebootSentinelFile) + if err != nil { + log.Fatalf("Failed to build reboot checker: %v", err) } config, err := rest.InClusterConfig() @@ -289,8 +277,8 @@ func main() { } lock := daemonsetlock.New(client, nodeID, dsNamespace, dsName, lockAnnotation, lockTTL, concurrency, lockReleaseDelay) - go rebootAsRequired(nodeID, rebooter, checker, window, lock, client) - go maintainRebootRequiredMetric(nodeID, checker) + go rebootAsRequired(nodeID, rebooter, rebootChecker, window, lock, client) + go maintainRebootRequiredMetric(nodeID, rebootChecker) http.Handle("/metrics", promhttp.Handler()) log.Fatal(http.ListenAndServe(fmt.Sprintf("%s:%d", metricsHost, 
metricsPort), nil)) @@ -477,7 +465,7 @@ func uncordon(client *kubernetes.Clientset, node *v1.Node) error { func maintainRebootRequiredMetric(nodeID string, checker checkers.Checker) { for { - if checker.CheckRebootRequired() { + if checker.RebootRequired() { rebootRequiredGauge.WithLabelValues(nodeID).Set(1) } else { rebootRequiredGauge.WithLabelValues(nodeID).Set(0) @@ -594,7 +582,7 @@ func rebootAsRequired(nodeID string, rebooter reboot.Rebooter, checker checkers. // And (2) check if we previously annotated the node that it was in the process of being rebooted, // And finally (3) if it has that annotation, to delete it. // This indicates to other node tools running on the cluster that this node may be a candidate for maintenance - if annotateNodes && !checker.CheckRebootRequired() { + if annotateNodes && !checker.RebootRequired() { if _, ok := node.Annotations[KuredRebootInProgressAnnotation]; ok { err := deleteNodeAnnotation(client, nodeID, KuredRebootInProgressAnnotation) if err != nil { @@ -617,7 +605,7 @@ func rebootAsRequired(nodeID string, rebooter reboot.Rebooter, checker checkers. preferNoScheduleTaint := taints.New(client, nodeID, preferNoScheduleTaintName, v1.TaintEffectPreferNoSchedule) // Remove taint immediately during startup to quickly allow scheduling again. - if !checker.CheckRebootRequired() { + if !checker.RebootRequired() { preferNoScheduleTaint.Disable() } @@ -636,7 +624,7 @@ func rebootAsRequired(nodeID string, rebooter reboot.Rebooter, checker checkers. 
continue } - if !checker.CheckRebootRequired() { + if !checker.RebootRequired() { log.Infof("Reboot not required") preferNoScheduleTaint.Disable() continue diff --git a/pkg/checkers/checker.go b/pkg/checkers/checker.go index 25d912463..cf71dffea 100644 --- a/pkg/checkers/checker.go +++ b/pkg/checkers/checker.go @@ -1,6 +1,7 @@ package checkers import ( + "fmt" "github.com/google/shlex" "github.com/kubereboot/kured/pkg/util" log "github.com/sirupsen/logrus" @@ -13,7 +14,7 @@ import ( -// CheckRebootRequired method which returns a single boolean +// RebootRequired method which returns a single boolean // clarifying whether a reboot is expected or not. type Checker interface { - CheckRebootRequired() bool + RebootRequired() bool } // FileRebootChecker is the default reboot checker. @@ -22,11 +23,21 @@ type FileRebootChecker struct { FilePath string } -// CheckRebootRequired checks the file presence +func NewRebootChecker(rebootSentinelCommand string, rebootSentinelFile string) (Checker, error) { + // An override of rebootSentinelCommand means a privileged command + if rebootSentinelCommand != "" { + log.Infof("Sentinel checker is (privileged) user provided command: %s", rebootSentinelCommand) + return NewCommandChecker(rebootSentinelCommand) + } + log.Infof("Sentinel checker is (unprivileged) testing for the presence of: %s", rebootSentinelFile) + return NewFileRebootChecker(rebootSentinelFile) +} + +// RebootRequired checks the file presence // needs refactoring to also return an error, instead of leaking it inside the code.
// This needs refactoring to get rid of NewCommand // This needs refactoring to only contain file location, instead of CheckCommand -func (rc FileRebootChecker) CheckRebootRequired() bool { +func (rc FileRebootChecker) RebootRequired() bool { if _, err := os.Stat(rc.FilePath); err == nil { log.Infof("Reboot required due to file %s presence", rc.FilePath) return true @@ -36,10 +47,10 @@ func (rc FileRebootChecker) CheckRebootRequired() bool { // NewFileRebootChecker is the constructor for the file based reboot checker // TODO: Add extra input validation on filePath string here -func NewFileRebootChecker(filePath string) *FileRebootChecker { +func NewFileRebootChecker(filePath string) (*FileRebootChecker, error) { return &FileRebootChecker{ FilePath: filePath, - } + }, nil } // CommandChecker is using a custom command to check @@ -52,10 +63,10 @@ type CommandChecker struct { Privileged bool } -// CheckRebootRequired for CommandChecker runs a command without returning +// RebootRequired for CommandChecker runs a command without returning -// any eventual error. THis should be later refactored to remove the util wrapper +// any eventual error. This should be later refactored to remove the util wrapper // and return the errors, instead of logging them here. -func (rc CommandChecker) CheckRebootRequired() bool { +func (rc CommandChecker) RebootRequired() bool { var cmdline []string if rc.Privileged { cmdline = util.PrivilegedHostCommand(rc.NamespacePid, rc.CheckCommand) @@ -85,14 +96,14 @@ func (rc CommandChecker) CheckRebootRequired() bool { // NewCommandChecker is the constructor for the commandChecker, and by default // runs new commands in a privileged fashion.
-func NewCommandChecker(sentinelCommand string) *CommandChecker { +func NewCommandChecker(sentinelCommand string) (*CommandChecker, error) { cmd, err := shlex.Split(sentinelCommand) if err != nil { - log.Fatalf("Error parsing provided sentinel command: %v", err) + return nil, fmt.Errorf("error parsing provided sentinel command: %v", err) } return &CommandChecker{ CheckCommand: cmd, NamespacePid: 1, Privileged: true, - } + }, nil } diff --git a/pkg/checkers/checker_test.go b/pkg/checkers/checker_test.go index 4df9870c1..683bddcd0 100644 --- a/pkg/checkers/checker_test.go +++ b/pkg/checkers/checker_test.go @@ -33,7 +33,7 @@ func Test_rebootRequired(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := CommandChecker{CheckCommand: tt.args.sentinelCommand, NamespacePid: 1, Privileged: false} - if got := a.CheckRebootRequired(); got != tt.want { + if got := a.RebootRequired(); got != tt.want { t.Errorf("rebootRequired() = %v, want %v", got, tt.want) } }) @@ -62,7 +62,7 @@ func Test_rebootRequired_fatals(t *testing.T) { for _, c := range cases { fatal = false a := CommandChecker{CheckCommand: c.param, NamespacePid: 1, Privileged: false} - a.CheckRebootRequired() + a.RebootRequired() assert.Equal(t, c.expectFatal, fatal) } diff --git a/pkg/reboot/command.go b/pkg/reboot/command.go index 07c17f238..74143028c 100644 --- a/pkg/reboot/command.go +++ b/pkg/reboot/command.go @@ -1,6 +1,7 @@ package reboot import ( + "fmt" "github.com/google/shlex" "github.com/kubereboot/kured/pkg/util" log "github.com/sirupsen/logrus" @@ -22,11 +23,14 @@ func (c CommandRebooter) Reboot() { // NewCommandRebooter is the constructor to create a CommandRebooter from a string not // yet shell lexed. You can skip this constructor if you parse the data correctly first // when instantiating a CommandRebooter instance. 
-func NewCommandRebooter(rebootCommand string) *CommandRebooter { +func NewCommandRebooter(rebootCommand string) (*CommandRebooter, error) { + if rebootCommand == "" { + return nil, fmt.Errorf("no reboot command specified") + } cmd, err := shlex.Split(rebootCommand) if err != nil { - log.Fatalf("Error parsing provided reboot command: %v", err) + return nil, fmt.Errorf("error %v when parsing reboot command %s", err, rebootCommand) } - return &CommandRebooter{RebootCommand: util.PrivilegedHostCommand(1, cmd)} + return &CommandRebooter{RebootCommand: util.PrivilegedHostCommand(1, cmd)}, nil } diff --git a/pkg/reboot/reboot.go b/pkg/reboot/reboot.go index 84ea93a89..fb81cd5d5 100644 --- a/pkg/reboot/reboot.go +++ b/pkg/reboot/reboot.go @@ -1,5 +1,10 @@ package reboot +import ( + "fmt" + log "github.com/sirupsen/logrus" +) + // Rebooter is the standard interface to use to execute // the reboot, after it has been considered as necessary. // The Reboot method does not expect any return, yet should @@ -7,3 +12,16 @@ package reboot type Rebooter interface { Reboot() } + +func NewRebooter(rebootMethod string, rebootCommand string, rebootSignal int) (Rebooter, error) { + switch { + case rebootMethod == "command": + log.Infof("Reboot command: %s", rebootCommand) + return NewCommandRebooter(rebootCommand) + case rebootMethod == "signal": + log.Infof("Reboot signal: %d", rebootSignal) + return NewSignalRebooter(rebootSignal) + default: + return nil, fmt.Errorf("invalid reboot-method configured %s, expected signal or command", rebootMethod) + } +} diff --git a/pkg/reboot/signal.go b/pkg/reboot/signal.go index ab9adbdbb..ef2595f71 100644 --- a/pkg/reboot/signal.go +++ b/pkg/reboot/signal.go @@ -1,6 +1,7 @@ package reboot import ( + "fmt" "os" "syscall" @@ -31,6 +32,9 @@ func (c SignalRebooter) Reboot() { // NewSignalRebooter is the constructor which sets the signal number. // The constructor does not yet validate any input. It should be done in a later commit. 
-func NewSignalRebooter(sig int) *SignalRebooter { - return &SignalRebooter{Signal: sig} +func NewSignalRebooter(sig int) (*SignalRebooter, error) { + if sig < 1 { + return nil, fmt.Errorf("invalid signal: %v", sig) + } + return &SignalRebooter{Signal: sig}, nil }