Skip to content

Commit

Permalink
swift-timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
swhitty committed Sep 8, 2024
1 parent 66636b3 commit 47e1616
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 97 deletions.
80 changes: 32 additions & 48 deletions FlyingSocks/Sources/Task+Timeout.swift
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
//
// TaskTimeout.swift
// TaskTimeout
// swift-timeout
//
// Created by Simon Whitty on 31/08/2024.
// Copyright 2024 Simon Whitty
//
// Distributed under the permissive MIT license
// Get the latest version from here:
//
// https://github.com/swhitty/TaskTimeout
// https://github.com/swhitty/swift-timeout
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand All @@ -31,11 +31,11 @@

import Foundation

package struct TimeoutError: LocalizedError {
package var errorDescription: String?
public struct TimeoutError: LocalizedError {
public var errorDescription: String?

package init(timeout: TimeInterval) {
self.errorDescription = "Task timed out before completion. Timeout: \(timeout) seconds."
init(_ description: String) {
self.errorDescription = description
}
}

Expand All @@ -45,34 +45,31 @@ package func withThrowingTimeout<T>(
seconds: TimeInterval,
body: () async throws -> sending T
) async throws -> sending T {
let transferringBody = { try await Transferring(body()) }
typealias NonSendableClosure = () async throws -> Transferring<T>
typealias SendableClosure = @Sendable () async throws -> Transferring<T>
return try await withoutActuallyEscaping(transferringBody) {
(_ fn: @escaping NonSendableClosure) async throws -> Transferring<T> in
let sendableFn = unsafeBitCast(fn, to: SendableClosure.self)
return try await _withThrowingTimeout(isolation: isolation, seconds: seconds, body: sendableFn)
}.value
}

// Sendable
private func _withThrowingTimeout<T: Sendable>(
isolation: isolated (any Actor)? = #isolation,
seconds: TimeInterval,
body: @Sendable @escaping () async throws -> T
) async throws -> T {
try await withThrowingTaskGroup(of: T.self, isolation: isolation) { group in
group.addTask {
try await body()
try await withoutActuallyEscaping(body) { escapingBody in
let bodyTask = Task {
defer { _ = isolation }
return try await Transferring(escapingBody())
}
group.addTask {
let timeoutTask = Task {
defer { bodyTask.cancel() }
try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000))
throw TimeoutError(timeout: seconds)
throw TimeoutError("Task timed out before completion. Timeout: \(seconds) seconds.")
}
let success = try await group.next()!
group.cancelAll()
return success
}

let bodyResult = await withTaskCancellationHandler {
await bodyTask.result
} onCancel: {
bodyTask.cancel()
}
timeoutTask.cancel()

if case .failure(let timeoutError) = await timeoutTask.result,
timeoutError is TimeoutError {
throw timeoutError
} else {
return try bodyResult.get()
}
}.value
}
#else
package func withThrowingTimeout<T>(
Expand Down Expand Up @@ -100,7 +97,7 @@ private func _withThrowingTimeout<T: Sendable>(
}
group.addTask {
try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000))
throw TimeoutError(timeout: seconds)
throw TimeoutError("Task timed out before completion. Timeout: \(seconds) seconds.")
}
let success = try await group.next()!
group.cancelAll()
Expand Down Expand Up @@ -132,26 +129,13 @@ package extension Task {
}
case .afterTimeout(let seconds):
if seconds > 0 {
return try await getValue(cancellingAfter: seconds)
return try await withThrowingTimeout(seconds: seconds) {
try await getValue(cancelling: .whenParentIsCancelled)
}
} else {
cancel()
return try await value
}
}
}

private func getValue(cancellingAfter seconds: TimeInterval) async throws -> Success {
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
_ = try await getValue(cancelling: .whenParentIsCancelled)
}
group.addTask {
try await Task<Never, Never>.sleep(nanoseconds: UInt64(seconds * 1_000_000_000))
throw TimeoutError(timeout: seconds)
}
_ = try await group.next()!
group.cancelAll()
return try await value
}
}
}
2 changes: 1 addition & 1 deletion FlyingSocks/Tests/SocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct SocketTests {
try s1.close()
try s2.close()

#expect(throws: SocketError.disconnected) {
#expect(throws: (any Error).self) {
try s1.read()
}
}
Expand Down
79 changes: 50 additions & 29 deletions FlyingSocks/Tests/Task+TimeoutTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ struct TaskTimeoutTests {
func timeoutThrowsError_WhenTimeoutExpires() async {
// given
let task = Task<Void, any Error>(timeout: 0.01) {
try? await Task.sleep(seconds: 10)
try await Task.sleep(seconds: 10)
}

// then
Expand Down Expand Up @@ -141,30 +141,26 @@ struct TaskTimeoutTests {
)
}

@MainActor
@Test
@Test @MainActor
func mainActor_ReturnsValue() async throws {
let val = try await withThrowingTimeout(seconds: 1) {
#if compiler(>=5.10)
MainActor.assertIsolated()
#endif
try await Task.sleep(nanoseconds: 1_000)
MainActor.assertIsolated()
return "Fish"
}
#expect(val == "Fish")
}

@Test
func mainActorThrowsError_WhenTimeoutExpires() async throws {
let task = Task { @MainActor in
func mainActorThrowsError_WhenTimeoutExpires() async {
await #expect(throws: TimeoutError.self) { @MainActor in
try await withThrowingTimeout(seconds: 0.05) {
MainActor.assertIsolated()
try? await Task.sleep(nanoseconds: 60_000_000_000)
defer { MainActor.assertIsolated() }
try await Task.sleep(nanoseconds: 60_000_000_000)
}
}

await #expect(throws: TimeoutError.self) {
try await task.value
}
}

@Test
Expand All @@ -186,17 +182,32 @@ struct TaskTimeoutTests {

@Test
func actor_ReturnsValue() async throws {
let val = try await TestActor().returningString("Fish")
#expect(val == "Fish")
#expect(
try await TestActor("Fish").returningValue() == "Fish"
)
}

@Test
func actorThrowsError_WhenTimeoutExpires() async {
await #expect(throws: TimeoutError.self) {
_ = try await TestActor().returningString(
after: 60,
timeout: 0.05
)
try await withThrowingTimeout(seconds: 0.05) {
try await TestActor().returningValue(after: 60, timeout: 0.05)
}
}
}

@Test
func timeout_cancels() async {
let task = Task {
try await withThrowingTimeout(seconds: 1) {
try await Task.sleep(nanoseconds: 1_000_000_000)
}
}

task.cancel()

await #expect(throws: CancellationError.self) {
try await task.value
}
}
}
Expand All @@ -206,9 +217,15 @@ extension Task where Success: Sendable, Failure == any Error {
// Start a new Task with a timeout.
init(priority: TaskPriority? = nil, timeout: TimeInterval, operation: @escaping @Sendable () async throws -> Success) {
self = Task(priority: priority) {
try await withThrowingTimeout(seconds: timeout) {
try await operation()
do {
return try await withThrowingTimeout(seconds: timeout) {
try await operation()
}
} catch {
print(error)
throw error
}

}
}
}
Expand All @@ -227,19 +244,23 @@ public struct NonSendable<T> {
}
}

private final actor TestActor {
private final actor TestActor<T: Sendable> {

private var value: T

func returningString(_ string: String = "Fish", after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> String {
try await returningValue(string, after: sleep, timeout: timeout)
init(_ value: T) {
self.value = value
}

func returningValue<T: Sendable>(_ value: T, after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T {
init() where T == String {
self.init("fish")
}

func returningValue(after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T {
try await withThrowingTimeout(seconds: timeout) {
if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) {
assertIsolated()
}
try await Task.sleep(seconds: sleep)
return value
try await Task.sleep(nanoseconds: UInt64(sleep * 1_000_000_000))
self.assertIsolated()
return self.value
}
}
}
71 changes: 52 additions & 19 deletions FlyingSocks/XCTests/Task+TimeoutTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ final class TaskTimeoutTests: XCTestCase {
func testTimeoutThrowsError_WhenTimeoutExpires() async {
// given
let task = Task<Void, any Error>(timeout: 0.5) {
try? await Task.sleep(seconds: 10)
try await Task.sleep(seconds: 10)
}

// then
Expand Down Expand Up @@ -146,9 +146,9 @@ final class TaskTimeoutTests: XCTestCase {
@MainActor
func testMainActor_ReturnsValue() async throws {
let val = try await withThrowingTimeout(seconds: 1) {
#if compiler(>=5.10)
MainActor.assertIsolated()
#endif
MainActor.safeAssertIsolated()
try await Task.sleep(nanoseconds: 1_000)
MainActor.safeAssertIsolated()
return "Fish"
}
XCTAssertEqual(val, "Fish")
Expand All @@ -158,10 +158,9 @@ final class TaskTimeoutTests: XCTestCase {
func testMainActorThrowsError_WhenTimeoutExpires() async {
do {
try await withThrowingTimeout(seconds: 0.05) {
#if compiler(>=5.10)
MainActor.assertIsolated()
#endif
try? await Task.sleep(nanoseconds: 60_000_000_000)
MainActor.safeAssertIsolated()
defer { MainActor.safeAssertIsolated() }
try await Task.sleep(nanoseconds: 60_000_000_000)
}
XCTFail("Expected Error")
} catch {
Expand All @@ -185,13 +184,13 @@ final class TaskTimeoutTests: XCTestCase {
}

func testActor_ReturnsValue() async throws {
let val = try await TestActor().returningString("Fish")
let val = try await TestActor("Fish").returningValue()
XCTAssertEqual(val, "Fish")
}

func testActorThrowsError_WhenTimeoutExpires() async {
do {
_ = try await TestActor().returningString(
_ = try await TestActor().returningValue(
after: 60,
timeout: 0.05
)
Expand All @@ -200,6 +199,23 @@ final class TaskTimeoutTests: XCTestCase {
XCTAssertTrue(error is TimeoutError)
}
}

func testTimeout_Cancels() async {
let task = Task {
try await withThrowingTimeout(seconds: 1) {
try await Task.sleep(nanoseconds: 1_000_000_000)
}
}

task.cancel()

do {
_ = try await task.value
XCTFail("Expected Error")
} catch {
XCTAssertTrue(error is CancellationError)
}
}
}

extension Task where Success: Sendable, Failure == any Error {
Expand Down Expand Up @@ -228,19 +244,36 @@ public struct NonSendable<T> {
}
}

private final actor TestActor {
private final actor TestActor<T: Sendable> {

private var value: T

init(_ value: T) {
self.value = value
}

func returningString(_ string: String = "Fish", after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> String {
try await returningValue(string, after: sleep, timeout: timeout)
init() where T == String {
self.init("fish")
}

func returningValue<T: Sendable>(_ value: T, after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T {
func returningValue(after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T {
try await withThrowingTimeout(seconds: timeout) {
if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) {
assertIsolated()
}
try await Task.sleep(seconds: sleep)
return value
try await Task.sleep(nanoseconds: UInt64(sleep * 1_000_000_000))
#if compiler(>=5.10)
self.assertIsolated()
#endif
return self.value
}
}
}

private extension MainActor {

static func safeAssertIsolated() {
#if compiler(>=5.10)
assertIsolated()
#else
precondition(Thread.isMainThread)
#endif
}
}

0 comments on commit 47e1616

Please sign in to comment.