From 4205d10f4c06c7e64040a4b66bed7836ecc41ca5 Mon Sep 17 00:00:00 2001 From: Jyotinder Date: Mon, 21 Oct 2024 16:37:32 +0530 Subject: [PATCH] refactor main.go --- dice.toml | 2 +- .../server/abstractserver/abstract_server.go | 7 + internal/server/httpServer.go | 2 + internal/server/resp/server.go | 2 + internal/server/server.go | 2 + internal/server/websocketServer.go | 2 + main.go | 166 +++++------------- 7 files changed, 63 insertions(+), 120 deletions(-) create mode 100644 internal/server/abstractserver/abstract_server.go diff --git a/dice.toml b/dice.toml index be63922b7..337170cfa 100644 --- a/dice.toml +++ b/dice.toml @@ -24,7 +24,7 @@ Enabled = true Port = 8379 [Performance] -WatchChanBufSize = 20000000 +WatchChanBufSize = 20000 ShardCronFrequency = 1000000000 MultiplexerPollTimeout = 100000000 MaxClients = 20000 diff --git a/internal/server/abstractserver/abstract_server.go b/internal/server/abstractserver/abstract_server.go new file mode 100644 index 000000000..e2d9d91bb --- /dev/null +++ b/internal/server/abstractserver/abstract_server.go @@ -0,0 +1,7 @@ +package abstractserver + +import "context" + +type AbstractServer interface { + Run(ctx context.Context) error +} diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index c15e29716..746d8485c 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/dicedb/dice/internal/server/abstractserver" "hash/crc32" "log/slog" "net/http" @@ -34,6 +35,7 @@ var unimplementedCommands = map[string]bool{ } type HTTPServer struct { + abstractserver.AbstractServer shardManager *shard.ShardManager ioChan chan *ops.StoreResponse httpServer *http.Server diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go index fd0a76583..fd4ef2b93 100644 --- a/internal/server/resp/server.go +++ b/internal/server/resp/server.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/dicedb/dice/internal/server/abstractserver" "log/slog" "net" "sync" @@ -36,6 +37,7 @@ const ( ) type Server struct { + abstractserver.AbstractServer Host string Port int serverFD int diff --git a/internal/server/server.go b/internal/server/server.go index 4963fa0ef..aee7e61d3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "github.com/dicedb/dice/internal/server/abstractserver" "io" "log/slog" "net" @@ -29,6 +30,7 @@ import ( ) type AsyncServer struct { + abstractserver.AbstractServer serverFD int maxClients int32 multiplexer iomultiplexer.IOMultiplexer diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go index b6f1d1df6..5e665f17b 100644 --- a/internal/server/websocketServer.go +++ b/internal/server/websocketServer.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/dicedb/dice/internal/server/abstractserver" "log/slog" "net" "net/http" @@ -35,6 +36,7 @@ var unimplementedCommandsWebsocket = map[string]bool{ } type WebsocketServer struct { + abstractserver.AbstractServer shardManager *shard.ShardManager ioChan chan *ops.StoreResponse websocketServer *http.Server diff --git a/main.go b/main.go index f997be6c4..5ca1b0af6 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "errors" "flag" "fmt" + "github.com/dicedb/dice/internal/server/abstractserver" "log/slog" "os" "os/signal" @@ -60,18 +61,20 @@ func main() { // Handle SIGTERM and SIGINT sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) + signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL) - var queryWatchChan chan dstore.QueryWatchEvent = nil - var cmdWatchChan chan dstore.CmdWatchEvent = nil + var ( + queryWatchChan chan dstore.QueryWatchEvent + cmdWatchChan chan dstore.CmdWatchEvent + serverErrCh = make(chan error, 2) + ) if config.EnableWatch { - queryWatchChan = make(chan dstore.QueryWatchEvent, config.DiceConfig.Performance.WatchChanBufSize) - cmdWatchChan = make(chan dstore.CmdWatchEvent, config.DiceConfig.Performance.WatchChanBufSize) + bufSize := config.DiceConfig.Performance.WatchChanBufSize + queryWatchChan = make(chan dstore.QueryWatchEvent, bufSize) + cmdWatchChan = make(chan dstore.CmdWatchEvent, bufSize) } - var serverErrCh chan error - // Get the number of available CPU cores on the machine using runtime.NumCPU(). // This determines the total number of logical processors that can be utilized // for parallel execution. Setting the maximum number of CPUs to the available @@ -79,13 +82,10 @@ func main() { // If multithreading is not enabled, server will run on a single core. var numCores int if config.EnableMultiThreading { - serverErrCh = make(chan error, 1) numCores = runtime.NumCPU() logr.Debug("The DiceDB server has started in multi-threaded mode.", slog.Int("number of cores", numCores)) } else { - serverErrCh = make(chan error, 2) logr.Debug("The DiceDB server has started in single-threaded mode.") - numCores = 1 } // The runtime.GOMAXPROCS(numCores) call limits the number of operating system @@ -107,132 +107,44 @@ func main() { var serverWg sync.WaitGroup - // Initialize the AsyncServer server - // Find a port and bind it - if !config.EnableMultiThreading { - asyncServer := server.NewAsyncServer(shardManager, queryWatchChan, logr) - if err := asyncServer.FindPortAndBind(); err != nil { - cancel() - logr.Error("Error finding and binding port", slog.Any("error", err)) - os.Exit(1) - } - - serverWg.Add(1) - go func() { - defer serverWg.Done() - // Run the server - err := asyncServer.Run(ctx) - - // Handling different server errors - if err != nil { - if errors.Is(err, context.Canceled) { - logr.Debug("Server was canceled") - } else if errors.Is(err, diceerrors.ErrAborted) { - logr.Debug("Server received abort command") - } else { - logr.Error( - "Server error", - slog.Any("error", err), - ) - } - serverErrCh <- err - } else { - logr.Debug("Server stopped without error") - } - }() - - // Goroutine to handle shutdown signals - wg.Add(1) - go func() { - defer wg.Done() - <-sigs - asyncServer.InitiateShutdown() - cancel() - }() - - // Initialize the HTTP server - httpServer := server.NewHTTPServer(shardManager, logr) - serverWg.Add(1) - go func() { - defer serverWg.Done() - // Run the HTTP server - err := httpServer.Run(ctx) - if err != nil { - if errors.Is(err, context.Canceled) { - logr.Debug("HTTP Server was canceled") - } else if errors.Is(err, diceerrors.ErrAborted) { - logr.Debug("HTTP received abort command") - } else { - logr.Error("HTTP Server error", slog.Any("error", err)) - } - serverErrCh <- err - } else { - logr.Debug("HTTP Server stopped without error") - } - }() - } else { + if config.EnableMultiThreading { if config.EnableProfiling { stopProfiling, err := startProfiling(logr) if err != nil { logr.Error("Profiling could not be started", slog.Any("error", err)) - os.Exit(1) + sigs <- syscall.SIGKILL } - defer stopProfiling() } workerManager := worker.NewWorkerManager(config.DiceConfig.Performance.MaxClients, shardManager) - // Initialize the RESP Server respServer := resp.NewServer(shardManager, workerManager, cmdWatchChan, serverErrCh, logr) serverWg.Add(1) - go func() { - defer serverWg.Done() - // Run the server - err := respServer.Run(ctx) + go runServer(ctx, &serverWg, respServer, logr, serverErrCh) + } else { + asyncServer := server.NewAsyncServer(shardManager, queryWatchChan, logr) + if err := asyncServer.FindPortAndBind(); err != nil { + logr.Error("Error finding and binding port", slog.Any("error", err)) + sigs <- syscall.SIGKILL + } - // Handling different server errors - if err != nil { - if errors.Is(err, context.Canceled) { - logr.Debug("Server was canceled") - } else if errors.Is(err, diceerrors.ErrAborted) { - logr.Debug("Server received abort command") - } else { - logr.Error("Server error", "error", err) - } - serverErrCh <- err - } else { - logr.Debug("Server stopped without error") - } - }() - - // Goroutine to handle shutdown signals - wg.Add(1) - go func() { - defer wg.Done() - <-sigs - respServer.Shutdown() - cancel() - }() + serverWg.Add(1) + go runServer(ctx, &serverWg, asyncServer, logr, serverErrCh) + + httpServer := server.NewHTTPServer(shardManager, logr) + serverWg.Add(1) + go runServer(ctx, &serverWg, httpServer, logr, serverErrCh) } websocketServer := server.NewWebSocketServer(shardManager, config.WebsocketPort, logr) serverWg.Add(1) + go runServer(ctx, &serverWg, websocketServer, logr, serverErrCh) + + wg.Add(1) go func() { - defer serverWg.Done() - // Run the Websocket server - err := websocketServer.Run(ctx) - if err != nil { - if errors.Is(err, context.Canceled) { - logr.Debug("Websocket Server was canceled") - } else if errors.Is(err, diceerrors.ErrAborted) { - logr.Debug("Websocket received abort command") - } else { - logr.Error("Websocket Server error", "error", err) - } - serverErrCh <- err - } else { - logr.Debug("Websocket Server stopped without error") - } + defer wg.Done() + <-sigs + cancel() }() go func() { @@ -255,6 +167,22 @@ func main() { logr.Debug("Server has shut down gracefully") } +func runServer(ctx context.Context, wg *sync.WaitGroup, srv abstractserver.AbstractServer, logr *slog.Logger, errCh chan<- error) { + defer wg.Done() + if err := srv.Run(ctx); err != nil { + switch { + case errors.Is(err, context.Canceled): + logr.Debug(fmt.Sprintf("%T was canceled", srv)) + case errors.Is(err, diceerrors.ErrAborted): + logr.Debug(fmt.Sprintf("%T received abort command", srv)) + default: + logr.Error(fmt.Sprintf("%T error", srv), slog.Any("error", err)) + } + errCh <- err + } else { + logr.Debug(fmt.Sprintf("%T stopped without error", srv)) + } +} func startProfiling(logr *slog.Logger) (func(), error) { // Start CPU profiling cpuFile, err := os.Create("cpu.prof")