Skip to content

Commit

Permalink
refactor: make apid stop gracefully and be stopped late
Browse files Browse the repository at this point in the history
This fixes the apid and machined shutdown sequences to perform a graceful
stop of the gRPC server with a timeout.

The sequences are also restructured to stop apid/machined as late as
possible, allowing access to the node while a long-running sequence
(e.g. upgrade or reset) is in progress.

Signed-off-by: Andrey Smirnov <[email protected]>
  • Loading branch information
smira committed Jul 29, 2022
1 parent 0cdf222 commit 2e79052
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 29 deletions.
79 changes: 61 additions & 18 deletions internal/app/apid/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ package apid
import (
"context"
"flag"
"fmt"
"log"
"os/signal"
"regexp"
"syscall"
"time"

"github.com/cosi-project/runtime/api/v1alpha1"
"github.com/cosi-project/runtime/pkg/state"
Expand Down Expand Up @@ -46,39 +50,48 @@ func runDebugServer(ctx context.Context) {

// Main is the entrypoint of apid.
//
// It runs apidMain and, if that returns an error, reports it through
// log.Fatal.
func Main() {
	err := apidMain()
	if err == nil {
		return
	}

	log.Fatal(err)
}

func apidMain() error {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer cancel()

log.SetFlags(log.Lshortfile | log.Ldate | log.Lmicroseconds | log.Ltime)

rbacEnabled = flag.Bool("enable-rbac", false, "enable RBAC for Talos API")

flag.Parse()

go runDebugServer(context.TODO())
go runDebugServer(ctx)

if err := startup.RandSeed(); err != nil {
log.Fatalf("failed to seed RNG: %v", err)
return fmt.Errorf("failed to seed RNG: %w", err)
}

runtimeConn, err := grpc.Dial("unix://"+constants.APIRuntimeSocketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
log.Fatalf("failed to dial runtime connection: %v", err)
return fmt.Errorf("failed to dial runtime connection: %w", err)
}

stateClient := v1alpha1.NewStateClient(runtimeConn)
resources := state.WrapCore(client.NewAdapter(stateClient))

tlsConfig, err := provider.NewTLSConfig(resources)
if err != nil {
log.Fatalf("failed to create remote certificate provider: %+v", err)
return fmt.Errorf("failed to create remote certificate provider: %w", err)
}

serverTLSConfig, err := tlsConfig.ServerConfig()
if err != nil {
log.Fatalf("failed to create OS-level TLS configuration: %v", err)
return fmt.Errorf("failed to create OS-level TLS configuration: %w", err)
}

clientTLSConfig, err := tlsConfig.ClientConfig()
if err != nil {
log.Fatalf("failed to create client TLS config: %v", err)
return fmt.Errorf("failed to create client TLS config: %w", err)
}

backendFactory := apidbackend.NewAPIDFactory(clientTLSConfig)
Expand Down Expand Up @@ -109,9 +122,22 @@ func Main() {
// register future pattern: method should have suffix "Stream"
router.RegisterStreamedRegex("Stream$")

var errGroup errgroup.Group
networkListener, err := factory.NewListener(
factory.Port(constants.ApidPort),
)
if err != nil {
return fmt.Errorf("error creating listner: %w", err)
}

errGroup.Go(func() error {
socketListener, err := factory.NewListener(
factory.Network("unix"),
factory.SocketPath(constants.APISocketPath),
)
if err != nil {
return fmt.Errorf("error creating listner: %w", err)
}

networkServer := func() *grpc.Server {
mode := authz.Disabled
if *rbacEnabled {
mode = authz.Enabled
Expand All @@ -122,9 +148,8 @@ func Main() {
Logger: log.New(log.Writer(), "apid/authz/injector/http ", log.Flags()).Printf,
}

return factory.ListenAndServe(
return factory.NewServer(
router,
factory.Port(constants.ApidPort),
factory.WithDefaultLog(),
factory.ServerOptions(
grpc.Creds(
Expand All @@ -140,18 +165,16 @@ func Main() {
factory.WithUnaryInterceptor(injector.UnaryInterceptor()),
factory.WithStreamInterceptor(injector.StreamInterceptor()),
)
})
}()

errGroup.Go(func() error {
socketServer := func() *grpc.Server {
injector := &authz.Injector{
Mode: authz.MetadataOnly,
Logger: log.New(log.Writer(), "apid/authz/injector/unix ", log.Flags()).Printf,
}

return factory.ListenAndServe(
return factory.NewServer(
router,
factory.Network("unix"),
factory.SocketPath(constants.APISocketPath),
factory.WithDefaultLog(),
factory.ServerOptions(
grpc.CustomCodec(proxy.Codec()), //nolint:staticcheck
Expand All @@ -164,9 +187,29 @@ func Main() {
factory.WithUnaryInterceptor(injector.UnaryInterceptor()),
factory.WithStreamInterceptor(injector.StreamInterceptor()),
)
}()

errGroup, ctx := errgroup.WithContext(ctx)

errGroup.Go(func() error {
return networkServer.Serve(networkListener)
})

if err := errGroup.Wait(); err != nil {
log.Fatalf("listen: %v", err)
}
errGroup.Go(func() error {
return socketServer.Serve(socketListener)
})

errGroup.Go(func() error {
<-ctx.Done()

shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()

factory.ServerGracefulStop(networkServer, shutdownCtx)
factory.ServerGracefulStop(socketServer, shutdownCtx)

return nil
})

return errGroup.Wait()
}
15 changes: 9 additions & 6 deletions internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru
LeaveEtcd,
).Append(
"stopServices",
StopServicesForUpgrade,
StopServicesEphemeral,
).Append(
"unmountUser",
UnmountUserDisks,
Expand All @@ -421,9 +421,6 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru
).Append(
"upgrade",
Upgrade,
).Append(
"stopEverything",
StopAllServices,
).Append(
"mountBoot",
MountBootPartition,
Expand All @@ -433,6 +430,9 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru
).Append(
"unmountBoot",
UnmountBootPartition,
).Append(
"stopEverything",
StopAllServices,
).Append(
"reboot",
Reboot,
Expand All @@ -453,8 +453,8 @@ func stopAllPhaselist(r runtime.Runtime, enableKexec bool) PhaseList {
)
default:
phases = phases.Append(
"stopEverything",
StopAllServices,
"stopServices",
StopServicesEphemeral,
).Append(
"unmountUser",
UnmountUserDisks,
Expand All @@ -481,6 +481,9 @@ func stopAllPhaselist(r runtime.Runtime, enableKexec bool) PhaseList {
enableKexec,
"unmountBoot",
UnmountBootPartition,
).Append(
"stopEverything",
StopAllServices,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -796,11 +796,11 @@ func StartAllServices(seq runtime.Sequence, data interface{}) (runtime.TaskExecu
}, "startAllServices"
}

// StopServicesForUpgrade represents the StopServicesForUpgrade task.
func StopServicesForUpgrade(seq runtime.Sequence, data interface{}) (runtime.TaskExecutionFunc, string) {
// StopServicesEphemeral represents the StopServicesEphemeral task.
func StopServicesEphemeral(seq runtime.Sequence, data interface{}) (runtime.TaskExecutionFunc, string) {
return func(ctx context.Context, logger *log.Logger, r runtime.Runtime) (err error) {
// stopping 'cri' service stops everything which depends on it (kubelet, etcd, ...)
return system.Services(nil).StopWithRevDepenencies(ctx, "cri", "udevd")
return system.Services(nil).StopWithRevDepenencies(ctx, "cri", "udevd", "trustd")
}, "stopServicesForUpgrade"
}

Expand Down
8 changes: 6 additions & 2 deletions internal/app/machined/pkg/system/services/machined.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log"
"os"
"path/filepath"
"time"

v1alpha1server "github.com/talos-systems/talos/internal/app/machined/internal/server/v1alpha1"
"github.com/talos-systems/talos/internal/app/machined/pkg/runtime"
Expand Down Expand Up @@ -134,15 +135,18 @@ func (s *machinedService) Main(ctx context.Context, r runtime.Runtime, logWriter
return err
}

defer server.Stop()

go func() {
//nolint:errcheck
server.Serve(listener)
}()

<-ctx.Done()

shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()

factory.ServerGracefulStop(server, shutdownCtx)

return nil
}

Expand Down
20 changes: 20 additions & 0 deletions pkg/grpc/factory/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package factory

import (
"context"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -257,3 +258,22 @@ func ListenAndServe(r Registrator, setters ...Option) (err error) {

return server.Serve(listener)
}

// ServerGracefulStop stops the gRPC server gracefully, bounded by shutdownCtx.
//
// GracefulStop takes no deadline, so it is run in a background goroutine
// while this function waits for either its completion or for shutdownCtx to
// be done, whichever comes first. In both cases server.Stop is called
// afterwards.
func ServerGracefulStop(server *grpc.Server, shutdownCtx context.Context) { //nolint:revive
	done := make(chan struct{})

	go func() {
		server.GracefulStop()
		close(done)
	}()

	// Wait for the graceful stop to finish or for the shutdown context to be done.
	select {
	case <-done:
	case <-shutdownCtx.Done():
	}

	// Both outcomes end with Stop.
	server.Stop()
}

0 comments on commit 2e79052

Please sign in to comment.