Skip to content

Commit

Permalink
clean up forked threads
Browse files Browse the repository at this point in the history
  • Loading branch information
chrismwendt committed Mar 10, 2023
1 parent dac4eae commit db8e3a7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 66 deletions.
94 changes: 33 additions & 61 deletions src/Network/WebSockets/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
-- Note that in production you want to use a real webserver such as snap or
-- warp.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
module Network.WebSockets.Server
( ServerApp
, runServer
Expand All @@ -19,17 +20,15 @@ module Network.WebSockets.Server


--------------------------------------------------------------------------------
import Control.Concurrent (threadDelay)
import Control.Concurrent (MVar, takeMVar, tryPutMVar,
newEmptyMVar)
import qualified Control.Concurrent.Async as Async
import Control.Exception (Exception, allowInterrupt,
bracket, bracketOnError,
finally, mask_, throwIO)
import Control.Monad (forever, void, when)
import qualified Data.IORef as IORef
import Data.Maybe (isJust)
import Control.Exception (Exception, bracket,
bracketOnError, finally, mask_,
throwIO)
import Network.Socket (Socket)
import qualified Network.Socket as S
import qualified System.Clock as Clock
import System.Timeout (timeout)


--------------------------------------------------------------------------------
Expand Down Expand Up @@ -110,59 +109,32 @@ defaultServerOptions = ServerOptions
runServerWithOptions :: ServerOptions -> ServerApp -> IO a
runServerWithOptions opts app = S.withSocketsDo $
bracket
(makeListenSocket host port)
S.close $ \sock -> mask_ $ forever $ do
allowInterrupt
(conn, _) <- S.accept sock

-- This IORef holds a time at which the thread may be killed. This time
-- can be extended by calling 'tickle'.
killRef <- IORef.newIORef =<< (+ killDelay) <$> getSecs
let tickle = IORef.writeIORef killRef =<< (+ killDelay) <$> getSecs

-- Update the connection options to call 'tickle' whenever a pong is
-- received.
let connOpts'
| not useKiller = connOpts
| otherwise = connOpts
{ connectionOnPong = tickle >> connectionOnPong connOpts
}

-- Run the application.
appAsync <- Async.asyncWithUnmask $ \unmask ->
(unmask $ do
runApp conn connOpts' app) `finally`
(S.close conn)

-- Install the killer if required.
when useKiller $ void $ Async.async (killer killRef appAsync)
where
host = serverHost opts
port = serverPort opts
connOpts = serverConnectionOptions opts

-- Get the current number of seconds on some clock.
getSecs = Clock.sec <$> Clock.getTime Clock.Monotonic

-- Parse the 'serverRequirePong' options.
useKiller = isJust $ serverRequirePong opts
killDelay = maybe 0 fromIntegral (serverRequirePong opts)

-- Thread that reads the killRef, and kills the application if enough time
-- has passed.
killer killRef appAsync = do
killAt <- IORef.readIORef killRef
now <- getSecs
appState <- Async.poll appAsync
case appState of
-- Already finished/killed/crashed, we can give up.
Just _ -> return ()
-- Should not be killed yet. Wait and try again.
Nothing | now < killAt -> do
threadDelay (fromIntegral killDelay * 1000 * 1000)
killer killRef appAsync
-- Time to kill.
_ -> Async.cancelWith appAsync PongTimeout
(makeListenSocket (serverHost opts) (serverPort opts))
S.close $ \sock -> do
heartbeat <- newEmptyMVar :: IO (MVar ())

let -- Update the connection options to perform a heartbeat whenever a
-- pong is received.
connOpts = (serverConnectionOptions opts)
{ connectionOnPong = tryPutMVar heartbeat () >> connectionOnPong connOpts
}

-- Kills the thread if pong was not received within the grace period.
reaper grace appAsync = timeout (grace * 10^(6 :: Int)) (takeMVar heartbeat) >>= \case
Nothing -> appAsync `Async.cancelWith` PongTimeout
Just _ -> reaper grace appAsync

connThread conn = case serverRequirePong opts of
Nothing -> runApp conn connOpts app
Just grace -> runApp conn connOpts app `Async.withAsync` reaper grace

mainThread = do
(conn, _) <- S.accept sock
Async.withAsyncWithUnmask
(\unmask -> unmask (connThread conn) `finally` S.close conn)
(\_ -> mainThread)

mask_ mainThread


--------------------------------------------------------------------------------
Expand Down
5 changes: 0 additions & 5 deletions websockets.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ Library
bytestring >= 0.9 && < 0.12,
bytestring-builder < 0.11,
case-insensitive >= 0.3 && < 1.3,
clock >= 0.8 && < 0.9,
containers >= 0.3 && < 0.7,
network >= 2.3 && < 3.2,
random >= 1.0 && < 1.3,
Expand Down Expand Up @@ -153,7 +152,6 @@ Test-suite websockets-tests
bytestring >= 0.9 && < 0.12,
bytestring-builder < 0.11,
case-insensitive >= 0.3 && < 1.3,
clock >= 0.8 && < 0.9,
containers >= 0.3 && < 0.7,
network >= 2.3 && < 3.2,
random >= 1.0 && < 1.3,
Expand Down Expand Up @@ -182,7 +180,6 @@ Executable websockets-example
bytestring >= 0.9 && < 0.12,
bytestring-builder < 0.11,
case-insensitive >= 0.3 && < 1.3,
clock >= 0.8 && < 0.9,
containers >= 0.3 && < 0.7,
network >= 2.3 && < 3.2,
random >= 1.0 && < 1.3,
Expand Down Expand Up @@ -213,7 +210,6 @@ Executable websockets-autobahn
bytestring >= 0.9 && < 0.12,
bytestring-builder < 0.11,
case-insensitive >= 0.3 && < 1.3,
clock >= 0.8 && < 0.9,
containers >= 0.3 && < 0.7,
network >= 2.3 && < 3.2,
random >= 1.0 && < 1.3,
Expand Down Expand Up @@ -242,7 +238,6 @@ Benchmark bench-mask
bytestring >= 0.9 && < 0.12,
bytestring-builder < 0.11,
case-insensitive >= 0.3 && < 1.3,
clock >= 0.8 && < 0.9,
containers >= 0.3 && < 0.7,
network >= 2.3 && < 3.2,
random >= 1.0 && < 1.3,
Expand Down

0 comments on commit db8e3a7

Please sign in to comment.