diff --git a/subcommand/inject-connect/command.go b/subcommand/inject-connect/command.go index ff631c6b8c..4420b774e2 100644 --- a/subcommand/inject-connect/command.go +++ b/subcommand/inject-connect/command.go @@ -14,6 +14,7 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "time" connectinject "github.com/hashicorp/consul-k8s/connect-inject" @@ -163,11 +164,11 @@ func (c *Command) init() { flags.Merge(c.flagSet, c.http.Flags()) c.help = flags.Usage(help, c.flagSet) - // Wait on an interrupt for exit, be sure to init it before running + // Wait on an interrupt or terminate for exit, be sure to init it before running // the controller so that we don't receive an interrupt before it's ready. if c.sigCh == nil { c.sigCh = make(chan os.Signal, 1) - signal.Notify(c.sigCh, os.Interrupt) + signal.Notify(c.sigCh, syscall.SIGINT, syscall.SIGTERM) } } @@ -390,8 +391,9 @@ func (c *Command) Run(args []string) int { }() select { - // Interrupted, gracefully exit. - case <-c.sigCh: + // Interrupted/terminated, gracefully exit. + case sig := <-c.sigCh: + c.UI.Info(fmt.Sprintf("%s received, shutting down", sig)) if err := server.Close(); err != nil { c.UI.Error(fmt.Sprintf("shutting down server: %v", err)) return 1 @@ -417,7 +419,11 @@ func (c *Command) Run(args []string) int { } func (c *Command) interrupt() { - c.sigCh <- os.Interrupt + c.sendSignal(syscall.SIGINT) +} + +func (c *Command) sendSignal(sig os.Signal) { + c.sigCh <- sig } func (c *Command) handleReady(rw http.ResponseWriter, req *http.Request) { diff --git a/subcommand/inject-connect/command_test.go b/subcommand/inject-connect/command_test.go index 5bb04a922e..005242d857 100644 --- a/subcommand/inject-connect/command_test.go +++ b/subcommand/inject-connect/command_test.go @@ -3,6 +3,7 @@ package connectinject import ( "fmt" "os" + "syscall" "testing" "time" @@ -206,38 +207,46 @@ func TestRun_CommandFailsWithInvalidListener(t *testing.T) { require.Contains(t, ui.ErrorWriter.String(), "Error listening: listen tcp: address 999999: missing port in address") } -// Test that when healthchecks are enabled that SIGINT exits the +// Test that when healthchecks are enabled that SIGINT/SIGTERM exits the // command cleanly. -func TestRun_CommandExitsCleanlyAfterSigInt(t *testing.T) { - k8sClient := fake.NewSimpleClientset() - ui := cli.NewMockUi() - cmd := Command{ - UI: ui, - clientset: k8sClient, - } - ports := freeport.MustTake(1) +func TestRun_CommandExitsCleanlyAfterSignal(t *testing.T) { - // NOTE: This url doesn't matter because Consul is never called. - os.Setenv(api.HTTPAddrEnvName, "http://0.0.0.0:9999") - defer os.Unsetenv(api.HTTPAddrEnvName) + t.Run("SIGINT", testSignalHandling(syscall.SIGINT)) + t.Run("SIGTERM", testSignalHandling(syscall.SIGTERM)) +} - // Start the command asynchronously and then we'll send an interrupt. - exitChan := runCommandAsynchronously(&cmd, []string{ - "-consul-k8s-image", "hashicorp/consul-k8s", - "-enable-health-checks-controller=true", - "-listen", fmt.Sprintf(":%d", ports[0]), - }) +func testSignalHandling(sig os.Signal) func(*testing.T) { + return func(t *testing.T) { + k8sClient := fake.NewSimpleClientset() + ui := cli.NewMockUi() + cmd := Command{ + UI: ui, + clientset: k8sClient, + } + ports := freeport.MustTake(1) + + // NOTE: This url doesn't matter because Consul is never called. + os.Setenv(api.HTTPAddrEnvName, "http://0.0.0.0:9999") + defer os.Unsetenv(api.HTTPAddrEnvName) + + // Start the command asynchronously and then we'll send an interrupt. + exitChan := runCommandAsynchronously(&cmd, []string{ + "-consul-k8s-image", "hashicorp/consul-k8s", + "-enable-health-checks-controller=true", + "-listen", fmt.Sprintf(":%d", ports[0]), + }) - // Send the interrupt. - cmd.interrupt() + // Send the signal + cmd.sendSignal(sig) - // Assert that it exits cleanly or timeout. - select { - case exitCode := <-exitChan: - require.Equal(t, 0, exitCode, ui.ErrorWriter.String()) - case <-time.After(time.Second * 1): - // Fail if the stopCh was not caught. - require.Fail(t, "timeout waiting for command to exit") + // Assert that it exits cleanly or timeout. + select { + case exitCode := <-exitChan: + require.Equal(t, 0, exitCode, ui.ErrorWriter.String()) + case <-time.After(time.Second * 1): + // Fail if the stopCh was not caught. + require.Fail(t, "timeout waiting for command to exit") + } } } diff --git a/subcommand/lifecycle-sidecar/command.go b/subcommand/lifecycle-sidecar/command.go index 4e0b1f2ecf..df4057ea4c 100644 --- a/subcommand/lifecycle-sidecar/command.go +++ b/subcommand/lifecycle-sidecar/command.go @@ -9,6 +9,7 @@ import ( "os/signal" "strings" "sync" + "syscall" "time" "github.com/hashicorp/consul-k8s/subcommand/flags" @@ -47,12 +48,12 @@ func (c *Command) init() { flags.Merge(c.flagSet, c.http.Flags()) c.help = flags.Usage(help, c.flagSet) - // Wait on an interrupt to exit. This channel must be initialized before + // Wait on an interrupt or terminate to exit. This channel must be initialized before // Run() is called so that there are no race conditions where the channel // is not defined. if c.sigCh == nil { c.sigCh = make(chan os.Signal, 1) - signal.Notify(c.sigCh, os.Interrupt) + signal.Notify(c.sigCh, syscall.SIGINT, syscall.SIGTERM) } } @@ -106,12 +107,12 @@ func (c *Command) Run(args []string) int { logger.Info("successfully synced service", "output", strings.TrimSpace(string(output))) } - // Re-loop after syncPeriod or exit if we receive an interrupt. + // Re-loop after syncPeriod or exit if we receive interrupt or terminate signals. select { case <-time.After(c.flagSyncPeriod): continue - case <-c.sigCh: - logger.Info("SIGINT received, shutting down") + case sig := <-c.sigCh: + logger.Info(fmt.Sprintf("%s received, shutting down", sig)) return 0 } } @@ -164,7 +165,11 @@ func (c *Command) parseConsulFlags() []string { // interrupt sends os.Interrupt signal to the command // so it can exit gracefully. This function is needed for tests func (c *Command) interrupt() { - c.sigCh <- os.Interrupt + c.sendSignal(syscall.SIGINT) +} + +func (c *Command) sendSignal(sig os.Signal) { + c.sigCh <- sig } func (c *Command) Synopsis() string { return synopsis } diff --git a/subcommand/lifecycle-sidecar/command_test.go b/subcommand/lifecycle-sidecar/command_test.go index ed4767d721..7f9e16d04d 100644 --- a/subcommand/lifecycle-sidecar/command_test.go +++ b/subcommand/lifecycle-sidecar/command_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "os" "path/filepath" + "syscall" "testing" "time" @@ -25,6 +26,37 @@ func TestRun_Defaults(t *testing.T) { require.Equal(t, "consul", cmd.flagConsulBinary) } +func TestRun_ExitsCleanlyonSignals(t *testing.T) { + t.Run("SIGINT", testRunSignalHandling(syscall.SIGINT)) + t.Run("SIGTERM", testRunSignalHandling(syscall.SIGTERM)) +} + +func testRunSignalHandling(sig os.Signal) func(*testing.T) { + return func(t *testing.T) { + tmpDir, configFile := createServicesTmpFile(t, servicesRegistration) + defer os.RemoveAll(tmpDir) + + ui := cli.NewMockUi() + cmd := Command{ + UI: ui, + } + // Run async because we need to kill it when the test is over. + exitChan := runCommandAsynchronously(&cmd, []string{ + "-service-config", configFile, + }) + cmd.sendSignal(sig) + + // Assert that it exits cleanly or timeout. + select { + case exitCode := <-exitChan: + require.Equal(t, 0, exitCode, ui.ErrorWriter.String()) + case <-time.After(time.Second * 1): + // Fail if the signal was not caught. + require.Fail(t, "timeout waiting for command to exit") + } + } +} + func TestRun_FlagValidation(t *testing.T) { t.Parallel() cases := []struct { diff --git a/subcommand/sync-catalog/command.go b/subcommand/sync-catalog/command.go index 5bf20a9fe1..386cc54c4e 100644 --- a/subcommand/sync-catalog/command.go +++ b/subcommand/sync-catalog/command.go @@ -9,6 +9,7 @@ import ( "os/signal" "regexp" "sync" + "syscall" "time" "github.com/deckarep/golang-set" @@ -145,12 +146,12 @@ func (c *Command) init() { c.help = flags.Usage(help, c.flags) - // Wait on an interrupt to exit. This channel must be initialized before + // Wait on an interrupt or terminate to exit. This channel must be initialized before // Run() is called so that there are no race conditions where the channel // is not defined. if c.sigCh == nil { c.sigCh = make(chan os.Signal, 1) - signal.Notify(c.sigCh, os.Interrupt) + signal.Notify(c.sigCh, syscall.SIGINT, syscall.SIGTERM) } } @@ -345,8 +346,9 @@ func (c *Command) Run(args []string) int { } return 1 - // Interrupted, gracefully exit - case <-c.sigCh: + // Interrupted/terminated, gracefully exit + case sig := <-c.sigCh: + c.logger.Info(fmt.Sprintf("%s received, shutting down", sig)) cancelF() if toConsulCh != nil { <-toConsulCh @@ -379,7 +381,11 @@ func (c *Command) Help() string { // interrupt sends os.Interrupt signal to the command // so it can exit gracefully. This function is needed for tests func (c *Command) interrupt() { - c.sigCh <- os.Interrupt + c.sendSignal(syscall.SIGINT) +} + +func (c *Command) sendSignal(sig os.Signal) { + c.sigCh <- sig } func (c *Command) validateFlags() error { diff --git a/subcommand/sync-catalog/command_test.go b/subcommand/sync-catalog/command_test.go index 3c62a2b68b..dee8faa8b2 100644 --- a/subcommand/sync-catalog/command_test.go +++ b/subcommand/sync-catalog/command_test.go @@ -2,6 +2,8 @@ package synccatalog import ( "context" + "os" + "syscall" "testing" "time" @@ -81,6 +83,48 @@ func TestRun_Defaults_SyncsConsulServiceToK8s(t *testing.T) { }) } +// Test that the command exits cleanly on signals +func TestRun_ExitCleanlyOnSignals(t *testing.T) { + t.Run("SIGINT", testSignalHandling(syscall.SIGINT)) + t.Run("SIGTERM", testSignalHandling(syscall.SIGTERM)) +} + +func testSignalHandling(sig os.Signal) func(*testing.T) { + return func(t *testing.T) { + k8s, testServer := completeSetup(t) + defer testServer.Stop() + + // Run the command. + ui := cli.NewMockUi() + cmd := Command{ + UI: ui, + clientset: k8s, + logger: hclog.New(&hclog.LoggerOptions{ + Name: t.Name(), + Level: hclog.Debug, + }), + } + + exitChan := runCommandAsynchronously(&cmd, []string{ + "-http-addr", testServer.HTTPAddr, + }) + cmd.sendSignal(sig) + + // Assert that it exits cleanly or timeout. + select { + case exitCode := <-exitChan: + require.Equal(t, 0, exitCode, ui.ErrorWriter.String()) + + // For some reason, this command cannot exit within 1s, + // so it's set higher than other tests in other commands + // to allow it to exit properly + case <-time.After(time.Second * 5): + // Fail if the signal was not caught. + require.Fail(t, "timeout waiting for command to exit") + } + } +} + // Test that when -add-k8s-namespace-suffix flag is used // k8s namespaces are appended to the service names synced to Consul func TestRun_ToConsulWithAddK8SNamespaceSuffix(t *testing.T) { diff --git a/subcommand/webhook-cert-manager/command.go b/subcommand/webhook-cert-manager/command.go index afceb42b15..b7817c00f9 100644 --- a/subcommand/webhook-cert-manager/command.go +++ b/subcommand/webhook-cert-manager/command.go @@ -13,6 +13,7 @@ import ( "os/signal" "strings" "sync" + "syscall" "time" "github.com/hashicorp/consul-k8s/helper/cert" @@ -65,12 +66,12 @@ func (c *Command) init() { flags.Merge(c.flagSet, c.k8s.Flags()) c.help = flags.Usage(help, c.flagSet) - // Wait on an interrupt to exit. This channel must be initialized before + // Wait on an interrupt or terminate to exit. This channel must be initialized before // Run() is called so that there are no race conditions where the channel // is not defined. if c.sigCh == nil { c.sigCh = make(chan os.Signal, 1) - signal.Notify(c.sigCh, os.Interrupt) + signal.Notify(c.sigCh, syscall.SIGINT, syscall.SIGTERM) } } @@ -167,11 +168,12 @@ func (c *Command) Run(args []string) int { go c.certWatcher(ctx, certCh, c.clientset, c.logger) - // We define a signal handler for OS interrupts, and when an SIGINT is received, + // We define a signal handler for OS interrupts, and when an SIGINT or SIGTERM is received, // we gracefully shut down, by first stopping our cert notifiers and then cancelling // all the contexts that have been created by the process. select { - case <-c.sigCh: + case sig := <-c.sigCh: + c.logger.Info(fmt.Sprintf("%s received, shutting down", sig)) cancelFunc() for _, notifier := range notifiers { notifier.Stop() @@ -367,7 +369,11 @@ func (c *Command) Synopsis() string { // interrupt sends os.Interrupt signal to the command // so it can exit gracefully. This function is needed for tests func (c *Command) interrupt() { - c.sigCh <- os.Interrupt + c.sendSignal(syscall.SIGINT) +} + +func (c *Command) sendSignal(sig os.Signal) { + c.sigCh <- sig } const synopsis = "Starts the Consul Kubernetes webhook-cert-manager" diff --git a/subcommand/webhook-cert-manager/command_test.go b/subcommand/webhook-cert-manager/command_test.go index bb26fc71b5..c550fedfad 100644 --- a/subcommand/webhook-cert-manager/command_test.go +++ b/subcommand/webhook-cert-manager/command_test.go @@ -4,6 +4,7 @@ import ( "context" "io/ioutil" "os" + "syscall" "testing" "time" @@ -18,6 +19,83 @@ import ( "k8s.io/client-go/kubernetes/fake" ) +func TestRun_ExitsCleanlyOnSignals(t *testing.T) { + t.Run("SIGINT", testSignalHandling(syscall.SIGINT)) + t.Run("SIGTERM", testSignalHandling(syscall.SIGTERM)) +} + +func testSignalHandling(sig os.Signal) func(*testing.T) { + return func(t *testing.T) { + webhookConfigOneName := "webhookOne" + webhookConfigTwoName := "webhookTwo" + + caBundleOne := []byte("bootstrapped-CA-one") + caBundleTwo := []byte("bootstrapped-CA-two") + + webhookOne := &admissionv1beta1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: webhookConfigOneName, + }, + Webhooks: []admissionv1beta1.MutatingWebhook{ + { + Name: "webhook-under-test", + ClientConfig: admissionv1beta1.WebhookClientConfig{ + CABundle: caBundleOne, + }, + }, + }, + } + webhookTwo := &admissionv1beta1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: webhookConfigTwoName, + }, + Webhooks: []admissionv1beta1.MutatingWebhook{ + { + Name: "webhookOne-under-test", + ClientConfig: admissionv1beta1.WebhookClientConfig{ + CABundle: caBundleTwo, + }, + }, + { + Name: "webhookTwo-under-test", + ClientConfig: admissionv1beta1.WebhookClientConfig{ + CABundle: caBundleTwo, + }, + }, + }, + } + + k8s := fake.NewSimpleClientset(webhookOne, webhookTwo) + ui := cli.NewMockUi() + cmd := Command{ + UI: ui, + clientset: k8s, + } + cmd.init() + + file, err := ioutil.TempFile("", "config.json") + require.NoError(t, err) + defer os.Remove(file.Name()) + + _, err = file.Write([]byte(configFile)) + require.NoError(t, err) + + exitCh := runCommandAsynchronously(&cmd, []string{ + "-config-file", file.Name(), + }) + cmd.sendSignal(sig) + + // Assert that it exits cleanly or timeout. + select { + case exitCode := <-exitCh: + require.Equal(t, 0, exitCode, ui.ErrorWriter.String()) + case <-time.After(time.Second * 1): + // Fail if the signal was not caught. + require.Fail(t, "timeout waiting for command to exit") + } + } +} + func TestRun_FlagValidation(t *testing.T) { t.Parallel()