From db8e3a7d58c040959e8a404e0614c65e738a98e1 Mon Sep 17 00:00:00 2001 From: Chris Wendt Date: Sat, 16 May 2020 16:37:22 -0600 Subject: [PATCH] clean up forked threads --- src/Network/WebSockets/Server.hs | 94 +++++++++++--------------------- websockets.cabal | 5 -- 2 files changed, 33 insertions(+), 66 deletions(-) diff --git a/src/Network/WebSockets/Server.hs b/src/Network/WebSockets/Server.hs index ed77d86..81a8bad 100644 --- a/src/Network/WebSockets/Server.hs +++ b/src/Network/WebSockets/Server.hs @@ -3,6 +3,7 @@ -- Note that in production you want to use a real webserver such as snap or -- warp. {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE LambdaCase #-} module Network.WebSockets.Server ( ServerApp , runServer @@ -19,17 +20,15 @@ module Network.WebSockets.Server -------------------------------------------------------------------------------- -import Control.Concurrent (threadDelay) +import Control.Concurrent (MVar, takeMVar, tryPutMVar, + newEmptyMVar) import qualified Control.Concurrent.Async as Async -import Control.Exception (Exception, allowInterrupt, - bracket, bracketOnError, - finally, mask_, throwIO) -import Control.Monad (forever, void, when) -import qualified Data.IORef as IORef -import Data.Maybe (isJust) +import Control.Exception (Exception, bracket, + bracketOnError, finally, mask_, + throwIO) import Network.Socket (Socket) import qualified Network.Socket as S -import qualified System.Clock as Clock +import System.Timeout (timeout) -------------------------------------------------------------------------------- @@ -110,59 +109,32 @@ defaultServerOptions = ServerOptions runServerWithOptions :: ServerOptions -> ServerApp -> IO a runServerWithOptions opts app = S.withSocketsDo $ bracket - (makeListenSocket host port) - S.close $ \sock -> mask_ $ forever $ do - allowInterrupt - (conn, _) <- S.accept sock - - -- This IORef holds a time at which the thread may be killed. This time - -- can be extended by calling 'tickle'. - killRef <- IORef.newIORef =<< (+ killDelay) <$> getSecs - let tickle = IORef.writeIORef killRef =<< (+ killDelay) <$> getSecs - - -- Update the connection options to call 'tickle' whenever a pong is - -- received. - let connOpts' - | not useKiller = connOpts - | otherwise = connOpts - { connectionOnPong = tickle >> connectionOnPong connOpts - } - - -- Run the application. - appAsync <- Async.asyncWithUnmask $ \unmask -> - (unmask $ do - runApp conn connOpts' app) `finally` - (S.close conn) - - -- Install the killer if required. - when useKiller $ void $ Async.async (killer killRef appAsync) - where - host = serverHost opts - port = serverPort opts - connOpts = serverConnectionOptions opts - - -- Get the current number of seconds on some clock. - getSecs = Clock.sec <$> Clock.getTime Clock.Monotonic - - -- Parse the 'serverRequirePong' options. - useKiller = isJust $ serverRequirePong opts - killDelay = maybe 0 fromIntegral (serverRequirePong opts) - - -- Thread that reads the killRef, and kills the application if enough time - -- has passed. - killer killRef appAsync = do - killAt <- IORef.readIORef killRef - now <- getSecs - appState <- Async.poll appAsync - case appState of - -- Already finished/killed/crashed, we can give up. - Just _ -> return () - -- Should not be killed yet. Wait and try again. - Nothing | now < killAt -> do - threadDelay (fromIntegral killDelay * 1000 * 1000) - killer killRef appAsync - -- Time to kill. - _ -> Async.cancelWith appAsync PongTimeout + (makeListenSocket (serverHost opts) (serverPort opts)) + S.close $ \sock -> do + heartbeat <- newEmptyMVar :: IO (MVar ()) + + let -- Update the connection options to perform a heartbeat whenever a + -- pong is received. + connOpts = (serverConnectionOptions opts) + { connectionOnPong = tryPutMVar heartbeat () >> connectionOnPong connOpts + } + + -- Kills the thread if pong was not received within the grace period. + reaper grace appAsync = timeout (grace * 10^(6 :: Int)) (takeMVar heartbeat) >>= \case + Nothing -> appAsync `Async.cancelWith` PongTimeout + Just _ -> reaper grace appAsync + + connThread conn = case serverRequirePong opts of + Nothing -> runApp conn connOpts app + Just grace -> runApp conn connOpts app `Async.withAsync` reaper grace + + mainThread = do + (conn, _) <- S.accept sock + Async.withAsyncWithUnmask + (\unmask -> unmask (connThread conn) `finally` S.close conn) + (\_ -> mainThread) + + mask_ mainThread -------------------------------------------------------------------------------- diff --git a/websockets.cabal b/websockets.cabal index 9b705ea..41129fc 100644 --- a/websockets.cabal +++ b/websockets.cabal @@ -93,7 +93,6 @@ Library bytestring >= 0.9 && < 0.12, bytestring-builder < 0.11, case-insensitive >= 0.3 && < 1.3, - clock >= 0.8 && < 0.9, containers >= 0.3 && < 0.7, network >= 2.3 && < 3.2, random >= 1.0 && < 1.3, @@ -153,7 +152,6 @@ Test-suite websockets-tests bytestring >= 0.9 && < 0.12, bytestring-builder < 0.11, case-insensitive >= 0.3 && < 1.3, - clock >= 0.8 && < 0.9, containers >= 0.3 && < 0.7, network >= 2.3 && < 3.2, random >= 1.0 && < 1.3, @@ -182,7 +180,6 @@ Executable websockets-example bytestring >= 0.9 && < 0.12, bytestring-builder < 0.11, case-insensitive >= 0.3 && < 1.3, - clock >= 0.8 && < 0.9, containers >= 0.3 && < 0.7, network >= 2.3 && < 3.2, random >= 1.0 && < 1.3, @@ -213,7 +210,6 @@ Executable websockets-autobahn bytestring >= 0.9 && < 0.12, bytestring-builder < 0.11, case-insensitive >= 0.3 && < 1.3, - clock >= 0.8 && < 0.9, containers >= 0.3 && < 0.7, network >= 2.3 && < 3.2, random >= 1.0 && < 1.3, @@ -242,7 +238,6 @@ Benchmark bench-mask bytestring >= 0.9 && < 0.12, bytestring-builder < 0.11, case-insensitive >= 0.3 && < 1.3, - clock >= 0.8 && < 0.9, containers >= 0.3 && < 0.7, network >= 2.3 && < 3.2, random >= 1.0 && < 1.3,