Skip to content

Commit

Permalink
fix(realtime): web socket message listener doesn't stop (#284)
Browse files Browse the repository at this point in the history
* fix(realtime): web socket message listener doesn't stop

* Do not expose task
  • Loading branch information
grdsdev authored Mar 28, 2024
1 parent a701bb1 commit 0b19580
Showing 1 changed file with 102 additions and 31 deletions.
133 changes: 102 additions & 31 deletions Sources/Realtime/V2/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ final class WebSocket: NSObject, URLSessionWebSocketDelegate, WebSocketClient, @
private let logger: (any SupabaseLogger)?

struct MutableState {
var task: URLSessionWebSocketTask?
var continuation: AsyncStream<ConnectionStatus>.Continuation?
var stream: SocketStream?
}

let mutableState = LockIsolated(MutableState())
Expand All @@ -56,8 +56,9 @@ final class WebSocket: NSObject, URLSessionWebSocketDelegate, WebSocketClient, @
func connect() -> AsyncStream<ConnectionStatus> {
mutableState.withValue { state in
let session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil)
state.task = session.webSocketTask(with: realtimeURL)
state.task?.resume()
let task = session.webSocketTask(with: realtimeURL)
state.stream = SocketStream(task: task)
task.resume()

let (stream, continuation) = AsyncStream<ConnectionStatus>.makeStream()
state.continuation = continuation
Expand All @@ -67,49 +68,48 @@ final class WebSocket: NSObject, URLSessionWebSocketDelegate, WebSocketClient, @

func disconnect(closeCode: URLSessionWebSocketTask.CloseCode) {
mutableState.withValue { state in
state.task?.cancel(with: closeCode, reason: nil)
state.stream?.cancel(with: closeCode)
}
}

func receive() -> AsyncThrowingStream<RealtimeMessageV2, any Error> {
let (stream, continuation) = AsyncThrowingStream<RealtimeMessageV2, any Error>.makeStream()

Task {
while let message = try await mutableState.task?.receive() {
do {
switch message {
case let .string(stringMessage):
logger?.verbose("Received message: \(stringMessage)")

guard let data = stringMessage.data(using: .utf8) else {
throw RealtimeError("Expected a UTF8 encoded message.")
}

let message = try JSONDecoder().decode(RealtimeMessageV2.self, from: data)
continuation.yield(message)

case .data:
fallthrough
default:
throw RealtimeError("Unsupported message type.")
mutableState.withValue { mutableState in
guard let stream = mutableState.stream else {
return .finished(
throwing: RealtimeError(
"receive() called before connect(). Make sure to call `connect()` before calling `receive()`."
)
)
}

return stream.map { message in
switch message {
case let .string(stringMessage):
self.logger?.verbose("Received message: \(stringMessage)")

guard let data = stringMessage.data(using: .utf8) else {
throw RealtimeError("Expected a UTF8 encoded message.")
}
} catch {
continuation.finish(throwing: error)

let message = try JSONDecoder().decode(RealtimeMessageV2.self, from: data)
return message

case .data:
fallthrough
default:
throw RealtimeError("Unsupported message type.")
}
}

continuation.finish()
.eraseToThrowingStream()
}

return stream
}

func send(_ message: RealtimeMessageV2) async throws {
let data = try JSONEncoder().encode(message)
let string = String(decoding: data, as: UTF8.self)

logger?.verbose("Sending message: \(string)")
try await mutableState.task?.send(.string(string))
try await mutableState.stream?.send(.string(string))
}

// MARK: - URLSessionWebSocketDelegate
Expand Down Expand Up @@ -144,3 +144,74 @@ final class WebSocket: NSObject, URLSessionWebSocketDelegate, WebSocketClient, @
mutableState.continuation?.yield(.error(error))
}
}

typealias WebSocketStream = AsyncThrowingStream<URLSessionWebSocketTask.Message, any Error>

final class SocketStream: AsyncSequence, Sendable {
typealias AsyncIterator = WebSocketStream.Iterator
typealias Element = URLSessionWebSocketTask.Message

struct MutableState {
var continuation: WebSocketStream.Continuation?
var stream: WebSocketStream?
}

private let task: URLSessionWebSocketTask
private let mutableState = LockIsolated(MutableState())

private func makeStreamIfNeeded() -> WebSocketStream {
mutableState.withValue { state in
if let stream = state.stream {
return stream
}

let stream = WebSocketStream { continuation in
state.continuation = continuation
waitForNextValue()
}

state.stream = stream
return stream
}
}

private func waitForNextValue() {
guard task.closeCode == .invalid else {
mutableState.continuation?.finish()
return
}

task.receive { [weak self] result in
guard let continuation = self?.mutableState.continuation else { return }

do {
let message = try result.get()
continuation.yield(message)
self?.waitForNextValue()
} catch {
continuation.finish(throwing: error)
}
}
}

init(task: URLSessionWebSocketTask) {
self.task = task
}

deinit {
mutableState.continuation?.finish()
}

func makeAsyncIterator() -> WebSocketStream.Iterator {
makeStreamIfNeeded().makeAsyncIterator()
}

func cancel(with closeCode: URLSessionWebSocketTask.CloseCode = .goingAway) {
task.cancel(with: closeCode, reason: nil)
mutableState.continuation?.finish()
}

func send(_ message: URLSessionWebSocketTask.Message) async throws {
try await task.send(message)
}
}

0 comments on commit 0b19580

Please sign in to comment.