diff --git a/NOnion/Http/TorHttpClient.fs b/NOnion/Http/TorHttpClient.fs index 0deb96ad..4ec8e3a9 100644 --- a/NOnion/Http/TorHttpClient.fs +++ b/NOnion/Http/TorHttpClient.fs @@ -47,8 +47,7 @@ type TorHttpClient (stream: TorStream, host: string) = let! httpResponse = receiveAll Array.empty - |> Async.StartAsTask - |> AsyncUtil.AwaitTaskWithTimeout Constants.HttpResponseTimeout + |> FSharpUtil.WithTimeout Constants.HttpResponseTimeout let header, body = let delimiter = diff --git a/NOnion/NOnion.fsproj b/NOnion/NOnion.fsproj index e7e9a9df..46dea86e 100644 --- a/NOnion/NOnion.fsproj +++ b/NOnion/NOnion.fsproj @@ -18,7 +18,6 @@ - diff --git a/NOnion/Network/TorCircuit.fs b/NOnion/Network/TorCircuit.fs index 31e2759d..d2eeddb8 100644 --- a/NOnion/Network/TorCircuit.fs +++ b/NOnion/Network/TorCircuit.fs @@ -337,8 +337,8 @@ type TorCircuit return! completionTask - |> AsyncUtil.AwaitTaskWithTimeout - Constants.CircuitOperationTimeout + |> Async.AwaitTask + |> FSharpUtil.WithTimeout Constants.CircuitOperationTimeout } member self.Extend (nodeDetail: CircuitNodeDetail) = @@ -412,8 +412,8 @@ type TorCircuit return! completionTask - |> AsyncUtil.AwaitTaskWithTimeout - Constants.CircuitOperationTimeout + |> Async.AwaitTask + |> FSharpUtil.WithTimeout Constants.CircuitOperationTimeout } @@ -483,8 +483,8 @@ type TorCircuit return! completionTask - |> AsyncUtil.AwaitTaskWithTimeout - Constants.CircuitOperationTimeout + |> Async.AwaitTask + |> FSharpUtil.WithTimeout Constants.CircuitOperationTimeout } member self.RegisterAsRendezvousPoint (cookie: array) = @@ -526,8 +526,8 @@ type TorCircuit return! completionTask - |> AsyncUtil.AwaitTaskWithTimeout - Constants.CircuitOperationTimeout + |> Async.AwaitTask + |> FSharpUtil.WithTimeout Constants.CircuitOperationTimeout } member self.ExtendAsync nodeDetail = @@ -570,8 +570,8 @@ type TorCircuit return! completionTask - |> AsyncUtil.AwaitTaskWithTimeout - Constants.CircuitOperationTimeout + |> Async.AwaitTask + |> FSharpUtil.WithTimeout Constants.CircuitOperationTimeout } member self.WaitingForRendezvousJoin @@ -609,7 +609,8 @@ type TorCircuit return! completionTask - |> AsyncUtil.AwaitTaskWithTimeout (TimeSpan.FromMinutes 2.) + |> Async.AwaitTask + |> FSharpUtil.WithTimeout (TimeSpan.FromMinutes 2.) } member self.Rendezvous diff --git a/NOnion/Network/TorGuard.fs b/NOnion/Network/TorGuard.fs index b1849696..2c8daba4 100644 --- a/NOnion/Network/TorGuard.fs +++ b/NOnion/Network/TorGuard.fs @@ -58,8 +58,8 @@ type TorGuard private (client: TcpClient, sslStream: SslStream) = SslProtocols.Tls12, false ) - |> AsyncUtil.AwaitNonGenericTaskWithTimeout - Constants.CircuitOperationTimeout + |> Async.AwaitTask + |> FSharpUtil.WithTimeout Constants.CircuitOperationTimeout ipEndpoint.ToString () |> sprintf "TorGuard: ssl connection to %s guard node authenticated" @@ -281,9 +281,7 @@ type TorGuard private (client: TcpClient, sslStream: SslStream) = TorLogger.Log "TorGuard: finished handshake process" //TODO: do security checks on handshake data } - |> Async.StartAsTask - |> AsyncUtil.AwaitNonGenericTaskWithTimeout - Constants.CircuitOperationTimeout + |> FSharpUtil.WithTimeout Constants.CircuitOperationTimeout member internal __.RegisterCircuit (circuit: ITorCircuit) : uint16 = let rec createCircuitId (retry: int) = diff --git a/NOnion/Network/TorStream.fs b/NOnion/Network/TorStream.fs index 2492ff1e..e34cc6f9 100644 --- a/NOnion/Network/TorStream.fs +++ b/NOnion/Network/TorStream.fs @@ -140,8 +140,8 @@ type TorStream (circuit: TorCircuit) = return! connectionProcessTcs - |> AsyncUtil.AwaitTaskWithTimeout - Constants.StreamCreationTimeout + |> Async.AwaitTask + |> FSharpUtil.WithTimeout Constants.StreamCreationTimeout } member self.ConnectToDirectory () = @@ -174,8 +174,8 @@ type TorStream (circuit: TorCircuit) = return! connectionProcessTcs - |> AsyncUtil.AwaitTaskWithTimeout - Constants.StreamCreationTimeout + |> Async.AwaitTask + |> FSharpUtil.WithTimeout Constants.StreamCreationTimeout } member self.ConnectToDirectoryAsync () = diff --git a/NOnion/Utility/AsyncUtil.fs b/NOnion/Utility/AsyncUtil.fs deleted file mode 100644 index 50179b7f..00000000 --- a/NOnion/Utility/AsyncUtil.fs +++ /dev/null @@ -1,37 +0,0 @@ -namespace NOnion - -open System -open System.Threading -open System.Threading.Tasks - -module AsyncUtil = - // Snippet from http://www.fssnip.net/hx/title/AsyncAwaitTask-with-timeouts - let AwaitTaskWithTimeout (timeout: TimeSpan) (task: Task<'T>) = - async { - use cts = new CancellationTokenSource () - use timer = Task.Delay (timeout, cts.Token) - let! completed = Async.AwaitTask <| Task.WhenAny (task, timer) - - if completed = (task :> Task) then - cts.Cancel () - - let! result = Async.AwaitTask task - return result - else - return raise TimeoutErrorException - } - - let AwaitNonGenericTaskWithTimeout (timeout: TimeSpan) (task: Task) = - async { - use cts = new CancellationTokenSource () - use timer = Task.Delay (timeout, cts.Token) - let! completed = Async.AwaitTask <| Task.WhenAny (task, timer) - - if completed = task then - cts.Cancel () - - let! result = Async.AwaitTask task - return result - else - return raise TimeoutErrorException - } diff --git a/NOnion/Utility/FSharpUtil.fs b/NOnion/Utility/FSharpUtil.fs index d1236545..9fbc6384 100644 --- a/NOnion/Utility/FSharpUtil.fs +++ b/NOnion/Utility/FSharpUtil.fs @@ -9,3 +9,34 @@ module FSharpUtil = (ExceptionDispatchInfo.Capture ex).Throw () failwith "Should be unreachable" ex + + type private Either<'Val, 'Err when 'Err :> Exception> = + | FailureResult of 'Err + | SuccessfulValue of 'Val + + let WithTimeout (timeSpan: TimeSpan) (job: Async<'R>) : Async<'R> = + async { + let read = + async { + let! value = job + return value |> SuccessfulValue |> Some + } + + let delay = + async { + let total = int timeSpan.TotalMilliseconds + do! Async.Sleep total + return FailureResult <| TimeoutException () |> Some + } + + let! dummyOption = Async.Choice ([ read; delay ]) + + match dummyOption with + | Some theResult -> + match theResult with + | SuccessfulValue r -> return r + | FailureResult _ -> return raise <| TimeoutErrorException + | None -> + // none of the jobs passed to Async.Choice returns None + return failwith "unreachable" + }