Skip to content

Commit

Permalink
dns-actions: make internal loops tail recursive
Browse files Browse the repository at this point in the history
  • Loading branch information
coot committed Aug 20, 2022
1 parent 1ba0a02 commit e3e8f60
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 75 deletions.
1 change: 1 addition & 0 deletions ouroboros-network/ouroboros-network.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ library
dns,
fingertree >=0.1.4.2 && <0.2,
iproute,
mtl,
nothunks,
network >=3.1.2 && <3.2,
pretty-simple,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,18 @@ import qualified Data.List.NonEmpty as NonEmpty

import Control.Exception (IOException)
import Control.Monad.Class.MonadAsync
#if !defined(mingw32_HOST_OS)

import Control.Monad.Class.MonadSTM.Strict
#endif

import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer
import Control.Monad.Except
import Control.Tracer (Tracer (..), traceWith)

#if !defined(mingw32_HOST_OS)

import System.Directory (getModificationTime)
#endif


import Data.IP (IP (..))
import Network.DNS (DNSError)
Expand Down Expand Up @@ -92,26 +93,26 @@ constantResource :: Applicative m => a -> Resource m err a
constantResource a = Resource (pure (Right a, constantResource a))


#if !defined(mingw32_HOST_OS)

type TimeStamp = UTCTime
#else
type TimeStamp = Time
#endif

#if defined(mingw32_HOST_OS)
-- | on Windows we will reinitialise the dns library every 60s.
--
dns_REINITIALISE_INTERVAL :: DiffTime
dns_REINITIALISE_INTERVAL = 60
#endif










getTimeStamp :: FilePath
-> IO TimeStamp
#if !defined(mingw32_HOST_OS)

getTimeStamp = getModificationTime
#else
getTimeStamp = addTime dns_REINITIALISE_INTERVAL <$> getMonotonicTime
#endif





-- | Strict version of 'Maybe' adjusted to the needs ot
Expand Down Expand Up @@ -184,39 +185,42 @@ resolverResource resolvConf = do
_ -> DNS.withResolver rs (pure . constantResource)

where
handlers :: FilePath
-> TimedResolver
-> [Handler IO
( Either (DNSorIOError IOException) DNS.Resolver
, Resource IO (DNSorIOError IOException) DNS.Resolver)]
handlers filePath tr =
[ Handler $
\(err :: IOException) ->
pure (Left (IOError err), go filePath tr)
, Handler $
\(err :: DNS.DNSError) ->
pure (Left (DNSError err), go filePath tr)
]
handlers :: [ Handler IO (Either (DNSorIOError IOException) a) ]
handlers = [ Handler $ pure . Left . IOError
, Handler $ pure . Left . DNSError
]

go :: FilePath
-> TimedResolver
-> Resource IO (DNSorIOError IOException) DNS.Resolver
go filePath tr@NoResolver = Resource $
do
modTime <- getTimeStamp filePath
resolver <- getResolver resolvConf
pure (Right resolver, go filePath (TimedResolver resolver modTime))
`catches` handlers filePath tr

go filePath tr@(TimedResolver resolver modTime) = Resource $
do
modTime' <- getTimeStamp filePath
result
<- (curry Right
<$> getTimeStamp filePath
<*> getResolver resolvConf)
`catches` handlers
case result of
Left err ->
pure (Left err, go filePath tr)
Right (modTime, resolver) -> do
pure (Right resolver, go filePath (TimedResolver resolver modTime))

go filePath tr@(TimedResolver resolver modTime) = Resource $ do
result <- runExceptT $ do
modTime' <- ExceptT $ (Right <$> getTimeStamp filePath)
`catches` handlers
if modTime' <= modTime
then pure (Right resolver, go filePath (TimedResolver resolver modTime))
then return (resolver, modTime)
else do
resolver' <- getResolver resolvConf
pure (Right resolver', go filePath (TimedResolver resolver' modTime'))
`catches` handlers filePath tr
resolver' <- ExceptT $ (Right <$> getResolver resolvConf)
`catches` handlers
return (resolver', modTime')
case result of
Left err ->
return (Left err, go filePath tr)
Right (resolver', modTime') ->
return (Right resolver', go filePath (TimedResolver resolver' modTime'))


-- | `Resource` which passes the 'DNS.Resolver' through a 'StrictTVar'. Better
Expand All @@ -225,7 +229,7 @@ resolverResource resolvConf = do
asyncResolverResource :: DNS.ResolvConf
-> IO (Resource IO (DNSorIOError IOException)
DNS.Resolver)
#if !defined(mingw32_HOST_OS)

asyncResolverResource resolvConf =
case DNS.resolvInfo resolvConf of
DNS.RCFilePath filePath -> do
Expand All @@ -234,45 +238,49 @@ asyncResolverResource resolvConf =
_ -> do
constantResource <$> getResolver resolvConf
where
handlers :: FilePath -> StrictTVar IO TimedResolver
-> [Handler IO
( Either (DNSorIOError IOException) DNS.Resolver
, Resource IO (DNSorIOError IOException) DNS.Resolver)]
handlers filePath resourceVar =
[ Handler $
\(err :: IOException) ->
pure (Left (IOError err), go filePath resourceVar)
, Handler $
\(err :: DNS.DNSError) ->
pure (Left (DNSError err), go filePath resourceVar)
]
handlers :: [ Handler IO (Either (DNSorIOError IOException) a) ]
handlers = [ Handler $ pure . Left . IOError
, Handler $ pure . Left . DNSError
]

go :: FilePath -> StrictTVar IO TimedResolver
-> Resource IO (DNSorIOError IOException) DNS.Resolver
go filePath resourceVar = Resource $ do
r <- atomically (readTVar resourceVar)
r <- readTVarIO resourceVar
case r of
NoResolver ->
do
modTime <- getModificationTime filePath
resolver <- getResolver resolvConf
atomically (writeTVar resourceVar (TimedResolver resolver modTime))
pure (Right resolver, go filePath resourceVar)
`catches` handlers filePath resourceVar

TimedResolver resolver modTime ->
do
modTime' <- getModificationTime filePath
result
<- (curry Right
<$> getTimeStamp filePath
<*> getResolver resolvConf)
`catches` handlers
case result of
Left err ->
pure (Left err, go filePath resourceVar)
Right (modTime, resolver) -> do
atomically (writeTVar resourceVar (TimedResolver resolver modTime))
pure (Right resolver, go filePath resourceVar)

TimedResolver resolver modTime -> do
result <- runExceptT $ do
modTime' <- ExceptT $ (Right <$> getTimeStamp filePath)
`catches` handlers
if modTime' <= modTime
then pure (Right resolver, go filePath resourceVar)
else do
resolver' <- getResolver resolvConf
atomically (writeTVar resourceVar (TimedResolver resolver' modTime'))
pure (Right resolver', go filePath resourceVar)
`catches` handlers filePath resourceVar
#else
asyncResolverResource resolvConf = resolverResource resolvConf
#endif
then return resolver
else do
resolver' <- ExceptT $ (Right <$> getResolver resolvConf)
`catches` handlers
lift $ atomically (writeTVar resourceVar (TimedResolver resolver' modTime'))
return resolver'
case result of
Left err ->
return (Left err, go filePath resourceVar)
Right resolver' ->
return (Right resolver', go filePath resourceVar)




-- | Like 'DNS.lookupA' but also return the TTL for the results.
--
Expand Down

0 comments on commit e3e8f60

Please sign in to comment.