From bf9ff2b4fc05305caee2865c216f8e5d7f140c23 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Mon, 18 Apr 2022 14:50:16 +0200 Subject: [PATCH] Use NWConnection instead of URLSessionWebSocketTask (#10) - Replace URLSessionWebSocketTask with NWConnection - Conform to published behavior of web browsers' WebSocket - Add support for Swift Concurrency --- .github/workflows/ci.yml | 12 + .github/workflows/swift.yml | 20 - .swiftformat | 4 +- Package.resolved | 34 - Package.swift | 27 +- README.md | 43 +- Sources/WebSocket/SystemWebSocket.swift | 647 ++++++++++++++++++ ...cketTaskCloseCode+WebSocketCloseCode.swift | 28 +- ...essionWebSocketTaskMessage+WebSocket.swift | 27 +- Sources/WebSocket/WebSocket.swift | 431 +++--------- Sources/WebSocket/WebSocketCloseCode.swift | 79 +++ Sources/WebSocket/WebSocketCloseResult.swift | 7 + Sources/WebSocket/WebSocketError.swift | 17 +- ...ocketMessage+URLSessionWebSocketTask.swift | 27 - Sources/WebSocket/WebSocketMessage.swift | 20 + Sources/WebSocket/WebSocketOptions.swift | 17 + .../Server/WebSocketServer.swift | 360 +++++----- .../WebSocketTests/SystemWebSocketTests.swift | 361 ++++++++++ ...RLSessionWebSocketTaskCloseCodeTests.swift | 7 +- Tests/WebSocketTests/WebSocketTests.swift | 285 -------- 20 files changed, 1507 insertions(+), 946 deletions(-) create mode 100644 .github/workflows/ci.yml delete mode 100644 .github/workflows/swift.yml delete mode 100644 Package.resolved create mode 100644 Sources/WebSocket/SystemWebSocket.swift create mode 100644 Sources/WebSocket/WebSocketCloseCode.swift create mode 100644 Sources/WebSocket/WebSocketCloseResult.swift delete mode 100644 Sources/WebSocket/WebSocketMessage+URLSessionWebSocketTask.swift create mode 100644 Sources/WebSocket/WebSocketMessage.swift create mode 100644 Sources/WebSocket/WebSocketOptions.swift create mode 100644 Tests/WebSocketTests/SystemWebSocketTests.swift delete mode 100644 Tests/WebSocketTests/WebSocketTests.swift diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..bb81615 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,12 @@ +on: push + +jobs: + test: + runs-on: macos-12 + + steps: + - uses: actions/checkout@v2 + # Available environments: https://github.com/actions/virtual-environments/blob/main/images/macos/macos-12-Readme.md#xcode + - run: xcversion select 13.3 + - run: swift package resolve + - run: swift test --skip-update diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml deleted file mode 100644 index 05ffd87..0000000 --- a/.github/workflows/swift.yml +++ /dev/null @@ -1,20 +0,0 @@ -name: Swift - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -jobs: - build: - name: Build and Test - runs-on: macos-latest - - steps: - - uses: actions/checkout@v2 - - name: Build - run: swift build -v - - name: Run tests - run: swift test -v - diff --git a/.swiftformat b/.swiftformat index 7a40186..35d6950 100644 --- a/.swiftformat +++ b/.swiftformat @@ -1,6 +1,8 @@ --funcattributes prev-line --minversion 0.47.2 ---maxwidth 100 +--maxwidth 96 --typeattributes prev-line --wraparguments before-first +--wrapparameters before-first --wrapcollections before-first +--xcodeindentation enabled diff --git a/Package.resolved b/Package.resolved deleted file mode 100644 index bd5d45f..0000000 --- a/Package.resolved +++ /dev/null @@ -1,34 +0,0 @@ -{ - "object": { - "pins": [ - { - "package": "swift-nio", - "repositoryURL": "https://github.com/apple/swift-nio.git", - "state": { - "branch": null, - "revision": "43931b7a7daf8120a487601530c8bc03ce711992", - "version": "2.25.1" - } - }, - { - "package": "Synchronized", - "repositoryURL": "https://github.com/shareup/synchronized.git", - "state": { - "branch": null, - "revision": "f01e4a1ee5fbf586d612a8dc0bc068603f6b9450", - "version": "3.0.0" - } - }, - { - "package": "WebSocketProtocol", - "repositoryURL": "https://github.com/shareup/websocket-protocol.git", - "state": { - "branch": null, - "revision": "bd6257e3c4b23484dfc73c550b025f96c8e151f6", - "version": "2.3.1" - } - } - ] - }, - "version": 1 -} diff --git a/Package.swift b/Package.swift index 4c80549..823ab7a 100644 --- a/Package.swift +++ b/Package.swift @@ -4,36 +4,23 @@ import PackageDescription let package = Package( name: "WebSocket", platforms: [ - .macOS(.v10_15), .iOS(.v13), .tvOS(.v13), .watchOS(.v6), + .macOS(.v11), .iOS(.v14), .tvOS(.v14), .watchOS(.v7), ], products: [ .library( name: "WebSocket", targets: ["WebSocket"] - )], - dependencies: [ - .package( - name: "Synchronized", - url: "https://github.com/shareup/synchronized.git", - from: "3.0.0" ), - .package( - name: "WebSocketProtocol", - url: "https://github.com/shareup/websocket-protocol.git", - from: "2.3.2" - ), - .package(name: "swift-nio", url: "https://github.com/apple/swift-nio.git", from: "2.0.0")], + ], + dependencies: [], targets: [ .target( name: "WebSocket", - dependencies: ["Synchronized", "WebSocketProtocol"]), + dependencies: [] + ), .testTarget( name: "WebSocketTests", - dependencies: [ - .product(name: "NIO", package: "swift-nio"), - .product(name: "NIOHTTP1", package: "swift-nio"), - .product(name: "NIOWebSocket", package: "swift-nio"), - "WebSocket", - ]) + dependencies: ["WebSocket"] + ), ] ) diff --git a/README.md b/README.md index 24c97dd..78bf3ae 100644 --- a/README.md +++ b/README.md @@ -2,31 +2,32 @@ ## _(macOS, iOS, iPadOS, tvOS, and watchOS)_ -A concrete implementation of a WebSocket client implemented by wrapping Apple's `URLSessionWebSocketTask` and conforming to [`WebSocketProtocol`](https://github.com/shareup/websocket-protocol). `WebSocket` exposes a simple API and conforms to Apple's Combine [`Publisher`](https://developer.apple.com/documentation/combine/publisher). +A concrete implementation of a WebSocket client implemented by wrapping Apple's [`NWConnection`](https://developer.apple.com/documentation/network/nwconnection). + +The public "interface" of `WebSocket` is a simple struct whose public "methods" are exposed as closures. The reason for this design is to make it easy to inject fake `WebSocket`s into your code for testing purposes. + +The actual implementation is `SystemWebSocket`, but this type is not publicly accessible. Instead, you can access it via `WebSocket.system(url:)`. `SystemWebSocket` tries its best to mirror the documented behavior of web browsers' [`WebSocket`](http://developer.mozilla.org/en-US/docs/Web/API/WebSocket). Please report any deviations as bugs. + +`WebSocket` exposes a simple API, makes heavy use of [Swift Concurrency](https://developer.apple.com/documentation/swift/swift_standard_library/concurrency), and conforms to Apple's Combine [`Publisher`](https://developer.apple.com/documentation/combine/publisher). ## Usage ```swift -let socket = WebSocket(url: url(49999)) - -let sub = socket.sink( - receiveCompletion: { print("Socket closed: \(String(describing: $0))") }, - receiveValue: { (result) in - switch result { - case .success(.open): - socket.send("First message") - case .success(.string(let incoming)): - print("Received \(incoming)") - case .failure: - socket.close() - default: - break - } - } -) -defer { sub.cancel() } - -socket.connect() +// `WebSocket` starts connecting to the specified `URL` immediately. +let socket = WebSocket.system(url: url(49999)) + +// Wait for `WebSocket` to be ready to send and receive messages. +try await socket.open() + +// Send a message to the server +try await socket.send(.text("hello")) + +// Receive messages from the server +for await message in socket.messages { + print(message) +} + +try await socket.close() ``` ## Tests diff --git a/Sources/WebSocket/SystemWebSocket.swift b/Sources/WebSocket/SystemWebSocket.swift new file mode 100644 index 0000000..305a53a --- /dev/null +++ b/Sources/WebSocket/SystemWebSocket.swift @@ -0,0 +1,647 @@ +import Combine +import Foundation +import Network +import os.log + +final actor SystemWebSocket: Publisher { + typealias Output = WebSocketMessage + typealias Failure = Never + + var isOpen: Bool { get async { + guard case .open = state else { return false } + return true + } } + + var isClosed: Bool { get async { + guard case .closed = state else { return false } + return true + } } + + private let url: URL + private let options: WebSocketOptions + private var _onOpen: WebSocketOnOpen + private var _onClose: WebSocketOnClose + private var state: WebSocketState = .unopened + + private var messageIndex = 0 // Used to identify sent messages + + private let subject = PassthroughSubject() + + private let webSocketQueue: DispatchQueue = .init( + label: "app.shareup.websocket.websocketqueue", + attributes: [], + autoreleaseFrequency: .workItem, + target: .global(qos: .default) + ) + + // Deliver messages to the subscribers on a separate queue because it's a bad idea + // to let the subscribers, who could potentially be doing long-running tasks with the + // data we send them, block our network queue. + private let subscriberQueue = DispatchQueue( + label: "app.shareup.websocket.subjectqueue", + attributes: [], + target: DispatchQueue.global(qos: .default) + ) + + init( + url: URL, + options: WebSocketOptions = .init(), + onOpen: @escaping WebSocketOnOpen = {}, + onClose: @escaping WebSocketOnClose = { _ in } + ) async throws { + self.url = url + self.options = options + _onOpen = onOpen + _onClose = onClose + try connect() + } + + deinit { + switch state { + case let .connecting(connection), let .open(connection): + connection.forceCancel() + default: + break + } + } + + nonisolated func receive(subscriber: S) + where S.Input == WebSocketMessage, S.Failure == Never + { + subject + .receive(on: subscriberQueue) + .receive(subscriber: subscriber) + } + + func open(timeout: TimeInterval? = nil) async throws { + switch state { + case .open: + return + + case .closing, .closed: + throw WebSocketError.openAfterConnectionClosed + + case .unopened, .connecting: + do { + try await withThrowingTaskGroup( + of: Void + .self + ) { (group: inout ThrowingTaskGroup) in + _ = group.addTaskUnlessCancelled { [weak self] in + guard let self = self else { return } + let _timeout = UInt64(timeout ?? self.options.timeoutIntervalForRequest) + try await Task.sleep(nanoseconds: _timeout * NSEC_PER_SEC) + throw CancellationError() + } + + _ = group.addTaskUnlessCancelled { [weak self] in + guard let self = self else { return } + while await !self.isOpen { + try await Task.sleep(nanoseconds: 10 * NSEC_PER_MSEC) + } + } + + _ = try await group.next() + group.cancelAll() + } + } catch { + doClose() + throw error + } + } + } + + func send(_ message: WebSocketMessage) async throws { + // Mirrors the document behavior of JavaScript's `WebSocket` + // http://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send + switch state { + case let .open(connection): + messageIndex += 1 + + os_log( + "send: index=%d message=%s", + log: .webSocket, + type: .debug, + messageIndex, + message.debugDescription + ) + + let context = NWConnection.ContentContext( + identifier: String(messageIndex), + metadata: [message.metadata] + ) + + try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in + connection.send( + content: message.contentAsData, + contentContext: context, + isComplete: true, + completion: .contentProcessed { (error: NWError?) in + if let error = error { + cont.resume(throwing: error) + } else { + cont.resume() + } + } + ) + } + + case .unopened, .connecting: + os_log( + "send message while connecting: %s", + log: .webSocket, + type: .error, + message.debugDescription + ) + throw WebSocketError.sendMessageWhileConnecting + + case .closing, .closed: + os_log( + "send message while closed: %s", + log: .webSocket, + type: .debug, + message.debugDescription + ) + } + } + + func close(_ closeCode: WebSocketCloseCode = .normalClosure) async throws { + switch state { + case let .connecting(conn), let .open(conn): + os_log( + "close connection: code=%d state=%{public}s", + log: .webSocket, + type: .debug, + closeCode.rawValue, + state.description + ) + + try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in + conn.send( + content: nil, + contentContext: .finalMessage, + isComplete: true, + completion: .contentProcessed { (error: Error?) in + if let error = error { + cont.resume(throwing: error) + } else { + cont.resume() + } + } + ) + } + startClosing(connection: conn, error: closeCode.error) + + case .unopened, .closing, .closed: + doClose() + } + } + + func forceClose(_ closeCode: WebSocketCloseCode) { + os_log( + "force close connection: code=%d state=%{public}s", + log: .webSocket, + type: .debug, + closeCode.rawValue, + state.description + ) + + doClose() + } + + func onOpen(_ block: @escaping WebSocketOnOpen) { + _onOpen = block + } + + func onClose(_ block: @escaping WebSocketOnClose) { + _onClose = block + } +} + +private extension SystemWebSocket { + var isUnopened: Bool { + switch state { + case .unopened: return true + default: return false + } + } + + func setState(_ state: WebSocketState) async { + self.state = state + } + + func connect() throws { + precondition(isUnopened) + + guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { + throw WebSocketError.invalidURL(url) + } + + let parameters = try self.parameters(with: components) + let connection = NWConnection(to: .url(url), using: parameters) + state = .connecting(connection) + connection.stateUpdateHandler = connectionStateUpdateHandler + connection.start(queue: webSocketQueue) + } + + func openReadyConnection(_ connection: NWConnection) { + os_log( + "open connection: connection_state=%{public}s", + log: .webSocket, + type: .debug, + connection.state.debugDescription + ) + + state = .open(connection) + _onOpen() + connection.receiveMessage(completion: onReceiveMessage) + } + + func startClosing(connection: NWConnection, error: NWError? = nil) { + state = .closing(error) + subject.send(completion: .finished) + connection.cancel() + } + + func doClose() { + // TODO: Switch to using `state.description` + os_log( + "do close connection: state=%{public}s", + log: .webSocket, + type: .debug, + state.debugDescription + ) + + switch state { + case .closing(nil): + state = .closed(nil) + subject.send(completion: .finished) + _onClose(normalClosure) + + case let .closing(.some(err)): + state = .closed(.connectionError(err)) + subject.send(completion: .finished) + _onClose(closureWithError(err)) + + case .unopened: + state = .closed(nil) + subject.send(completion: .finished) + _onClose(abnormalClosure) + + case let .connecting(conn), let .open(conn): + state = .closed(nil) + subject.send(completion: .finished) + _onClose(abnormalClosure) + conn.forceCancel() + + case .closed: + // `PassthroughSubject` only sends completions once. + subject.send(completion: .finished) + } + } + + func doCloseWithError(_ error: WebSocketError) { + // TODO: Switch to using `state.description` + os_log( + "do close connection: state=%{public}s error=%{public}s", + log: .webSocket, + type: .debug, + state.debugDescription, + String(describing: error) + ) + + switch state { + case let .closing(.some(err)): + state = .closed(.connectionError(err)) + subject.send(completion: .finished) + _onClose(closureWithError(err)) + + case .closing(nil): + state = .closed(error) + subject.send(completion: .finished) + _onClose(closureWithError(error)) + + case .unopened: + state = .closed(error) + subject.send(completion: .finished) + _onClose(closureWithError(error)) + + case let .connecting(conn), let .open(conn): + state = .closed(nil) + subject.send(completion: .finished) + _onClose(closureWithError(error)) + conn.forceCancel() + + case .closed: + // `PassthroughSubject` only sends completions once. + subject.send(completion: .finished) + } + } +} + +private extension SystemWebSocket { + func host(with urlComponents: URLComponents) throws -> NWEndpoint.Host { + guard let host = urlComponents.host else { + throw WebSocketError.invalidURLComponents(urlComponents) + } + return NWEndpoint.Host(host) + } + + func port(with urlComponents: URLComponents) throws -> NWEndpoint.Port { + if let raw = urlComponents.port, let port = NWEndpoint.Port(rawValue: UInt16(raw)) { + return port + } else if urlComponents.scheme == "ws" { + return NWEndpoint.Port.http + } else if urlComponents.scheme == "wss" { + return NWEndpoint.Port.https + } else { + throw WebSocketError.invalidURLComponents(urlComponents) + } + } + + func parameters(with urlComponents: URLComponents) throws -> NWParameters { + let parameters: NWParameters + switch urlComponents.scheme { + case "ws": + parameters = .tcp + case "wss": + parameters = .tls + default: + throw WebSocketError.invalidURLComponents(urlComponents) + } + + let webSocketOptions = NWProtocolWebSocket.Options() + webSocketOptions.maximumMessageSize = options.maximumMessageSize + webSocketOptions.autoReplyPing = true + + parameters.defaultProtocolStack.applicationProtocols.insert(webSocketOptions, at: 0) + + return parameters + } +} + +private extension SystemWebSocket { + var connectionStateUpdateHandler: (NWConnection.State) -> Void { + { [weak self] (connectionState: NWConnection.State) in + Task { [weak self] in + guard let self = self else { return } + + let state = await self.state + + // TODO: Switch to using `state.description` + os_log( + "connection state update: connection_state=%{public}s state=%{public}s", + log: .webSocket, + type: .debug, + connectionState.debugDescription, + state.debugDescription + ) + + switch connectionState { + case .setup: + break + + case let .waiting(error): + await self.doCloseWithError(.connectionError(error)) + + case .preparing: + break + + case .ready: + switch state { + case let .connecting(conn): + await self.openReadyConnection(conn) + + case .open: + // TODO: Handle betterPathUpdate here? + break + + case .unopened, .closing, .closed: + // TODO: Switch to using `state.description` + os_log( + "unexpected connection ready: state=%{public}s", + log: .webSocket, + type: .error, + state.debugDescription + ) + } + + case let .failed(error): + switch state { + case let .connecting(conn), let .open(conn): + await self.startClosing(connection: conn, error: error) + + case .unopened, .closing, .closed: + break + } + + case .cancelled: + switch state { + case let .connecting(conn), let .open(conn): + await self.startClosing(connection: conn) + + case .unopened, .closing: + await self.doClose() + + case .closed: + break + } + + @unknown default: + assertionFailure("Unknown state '\(state)'") + } + } + } + } + + var onReceiveMessage: (Data?, NWConnection.ContentContext?, Bool, NWError?) -> Void { + { [weak self] data, context, isMessageComplete, error in + guard let self = self else { return } + guard isMessageComplete else { return } + + Task { + switch (data, context, error) { + case let (.some(data), .some(context), .none): + await self.handleSuccessfulMessage(data: data, context: context) + case let (.none, _, .some(error)): + await self.handleMessageWithError(error) + default: + await self.handleUnknownMessage(data: data, context: context, error: error) + } + } + } + } + + func handleSuccessfulMessage(data: Data, context: NWConnection.ContentContext) { + guard case let .open(connection) = state else { return } + + switch context.websocketMessageType { + case .binary: + os_log( + "receive binary: size=%d", + log: .webSocket, + type: .debug, + data.count + ) + subject.send(.data(data)) + + case .text: + guard let text = String(data: data, encoding: .utf8) else { + startClosing(connection: connection, error: .posix(.EBADMSG)) + return + } + os_log( + "receive text: content=%s", + log: .webSocket, + type: .debug, + text + ) + subject.send(.text(text)) + + case .close: + doClose() + + case .pong: + // TODO: Handle pongs at some point + break + + default: + let messageType = String(describing: context.websocketMessageType) + assertionFailure("Unexpected message type: \(messageType)") + } + + connection.receiveMessage(completion: onReceiveMessage) + } + + func handleMessageWithError(_ error: NWError) { + switch state { + case let .connecting(conn), let .open(conn): + + startClosing(connection: conn, error: error) + + case .unopened, .closing, .closed: + // TODO: Should we call `doClose()` here, instead? + break + } + } + + func handleUnknownMessage( + data: Data?, + context: NWConnection.ContentContext?, + error: NWError? + ) { + func describeInputs() -> String { + String(describing: String(data: data ?? Data(), encoding: .utf8)) + " " + + String(describing: context) + " " + String(describing: error) + } + + // TODO: Switch to using `state.description` + os_log( + "unknown message: state=%{public}s message=%s", + log: .webSocket, + type: .error, + state.debugDescription, + describeInputs() + ) + + doCloseWithError(WebSocketError.receiveUnknownMessageType) + } +} + +private extension WebSocketMessage { + var metadata: NWProtocolWebSocket.Metadata { + switch self { + case .data: return .init(opcode: .binary) + case .text: return .init(opcode: .text) + } + } + + var contentAsData: Data { + switch self { + case let .data(data): return data + case let .text(text): return Data(text.utf8) + } + } +} + +private enum WebSocketState: CustomStringConvertible, CustomDebugStringConvertible { + case unopened + case connecting(NWConnection) + case open(NWConnection) + case closing(NWError?) + case closed(WebSocketError?) + + var description: String { + switch self { + case .unopened: return "unopened" + case .connecting: return "connecting" + case .open: return "open" + case .closing: return "closing" + case .closed: return "closed" + } + } + + var debugDescription: String { + switch self { + case .unopened: return "unopened" + case let .connecting(conn): return "connecting(\(String(reflecting: conn)))" + case let .open(conn): return "open(\(String(reflecting: conn)))" + case let .closing(error): return "closing(\(error.debugDescription))" + case let .closed(error): return "closed(\(error.debugDescription))" + } + } +} + +private extension NWConnection.ContentContext { + var webSocketMetadata: NWProtocolWebSocket.Metadata? { + let definition = NWProtocolWebSocket.definition + return protocolMetadata(definition: definition) as? NWProtocolWebSocket.Metadata + } + + var websocketMessageType: NWProtocolWebSocket.Opcode? { + webSocketMetadata?.opcode + } +} + +private extension NWError { + var shouldCloseConnectionWhileConnectingOrOpen: Bool { + switch self { + case .posix(.ECANCELED), .posix(.ENOTCONN): + return false + default: + print("Unhandled error in '\(#function)': \(debugDescription)") + return true + } + } + + var closeCode: WebSocketCloseCode { + switch self { + case .posix(.ECANCELED): + return .normalClosure + default: + print("Unhandled error in '\(#function)': \(debugDescription)") + return .normalClosure + } + } +} + +private extension NWConnection.State { + var debugDescription: String { + switch self { + case .setup: return "setup" + case let .waiting(error): return "waiting(\(String(reflecting: error)))" + case .preparing: return "preparing" + case .ready: return "ready" + case let .failed(error): return "failed(\(String(reflecting: error)))" + case .cancelled: return "cancelled" + @unknown default: return "unknown" + } + } +} + +private extension Optional where Wrapped == NWError { + var debugDescription: String { + guard case let .some(error) = self else { return "" } + return String(reflecting: error) + } +} diff --git a/Sources/WebSocket/URLSessionWebSocketTaskCloseCode+WebSocketCloseCode.swift b/Sources/WebSocket/URLSessionWebSocketTaskCloseCode+WebSocketCloseCode.swift index 67f1f29..875c542 100644 --- a/Sources/WebSocket/URLSessionWebSocketTaskCloseCode+WebSocketCloseCode.swift +++ b/Sources/WebSocket/URLSessionWebSocketTaskCloseCode+WebSocketCloseCode.swift @@ -1,8 +1,32 @@ import Foundation -import WebSocketProtocol -public extension URLSessionWebSocketTask.CloseCode { +extension URLSessionWebSocketTask.CloseCode { init?(_ closeCode: WebSocketCloseCode) { self.init(rawValue: closeCode.rawValue) } } + +extension WebSocketCloseCode { + init?(_ closeCode: URLSessionWebSocketTask.CloseCode?) { + guard let closeCode = closeCode else { return nil } + self.init(rawValue: closeCode.rawValue) + } + + var urlSessionCloseCode: URLSessionWebSocketTask.CloseCode { + switch self { + case .invalid: return .invalid + case .normalClosure: return .normalClosure + case .goingAway: return .goingAway + case .protocolError: return .protocolError + case .unsupportedData: return .unsupportedData + case .noStatusReceived: return .noStatusReceived + case .abnormalClosure: return .abnormalClosure + case .invalidFramePayloadData: return .invalidFramePayloadData + case .policyViolation: return .policyViolation + case .messageTooBig: return .messageTooBig + case .mandatoryExtensionMissing: return .mandatoryExtensionMissing + case .internalServerError: return .internalServerError + case .tlsHandshakeFailure: return .tlsHandshakeFailure + } + } +} diff --git a/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift b/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift index 73946a3..ccb58b8 100644 --- a/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift +++ b/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift @@ -6,10 +6,35 @@ extension URLSessionWebSocketTask.Message: CustomDebugStringConvertible { case let .string(text): return text case let .data(data): - return "\(data.count) bytes" + return "<\(data.count) bytes>" @unknown default: assertionFailure("Unsupported message: \(self)") return "" } } } + +extension WebSocketMessage { + init(_ message: URLSessionWebSocketTask.Message) { + switch message { + case let .data(data): + self = .data(data) + case let .string(string): + self = .text(string) + @unknown default: + assertionFailure("Unknown WebSocket Message type") + self = .text("") + } + } +} + +extension Result: CustomDebugStringConvertible where Success == WebSocketMessage { + public var debugDescription: String { + switch self { + case let .success(message): + return message.debugDescription + case let .failure(error): + return error.localizedDescription + } + } +} diff --git a/Sources/WebSocket/WebSocket.swift b/Sources/WebSocket/WebSocket.swift index e3b87c3..e0547eb 100644 --- a/Sources/WebSocket/WebSocket.swift +++ b/Sources/WebSocket/WebSocket.swift @@ -1,382 +1,111 @@ import Combine import Foundation -import os.log -import Synchronized -import WebSocketProtocol -public final class WebSocket: WebSocketProtocol { - public typealias Output = Result - public typealias Failure = Swift.Error +public typealias WebSocketOnOpen = () -> Void +public typealias WebSocketOnClose = (WebSocketCloseResult) -> Void - private enum State: CustomDebugStringConvertible { - case unopened - case connecting(URLSession, URLSessionWebSocketTask, WebSocketDelegate) - case open(URLSession, URLSessionWebSocketTask, WebSocketDelegate) - case closing - case closed(WebSocketError) +public struct WebSocket { + /// Sets a closure to be called when the WebSocket connects successfully. + public var onOpen: (@escaping WebSocketOnOpen) async -> Void - var webSocketSessionAndTask: (URLSession, URLSessionWebSocketTask)? { - switch self { - case let .connecting(session, task, _), let .open(session, task, _): - return (session, task) - case .unopened, .closing, .closed: - return nil - } - } - - var debugDescription: String { - switch self { - case .unopened: return "unopened" - case .connecting: return "connecting" - case .open: return "open" - case .closing: return "closing" - case .closed: return "closed" - } - } - } + /// Sets a closure to be called when the WebSocket closes. + public var onClose: (@escaping WebSocketOnClose) async -> Void - /// The maximum number of bytes to buffer before the receive call fails with an error. - /// Default: 1 MiB - public var maximumMessageSize: Int = 1024 * 1024 { - didSet { sync { - guard let (_, task) = state.webSocketSessionAndTask else { return } - task.maximumMessageSize = maximumMessageSize - } } - } - - public var isOpen: Bool { sync { - guard case .open = state else { return false } - return true - } } - - public var isClosed: Bool { sync { - guard case .closed = state else { return false } - return true - } } + /// Opens the WebSocket connect with an optional timeout. After this function + /// is awaited, the WebSocket connection is open ready to be used. If the + /// connection fails or times out, an error is thrown. + public var open: (TimeInterval?) async throws -> Void - private let lock = RecursiveLock() - private func sync(_ block: () throws -> T) rethrows -> T { try lock.locked(block) } + /// Sends a close frame to the server with the given close code. + public var close: (WebSocketCloseCode) async throws -> Void - private let url: URL + /// Sends a text or binary message. + public var send: (WebSocketMessage) async throws -> Void - private let timeoutIntervalForRequest: TimeInterval - private let timeoutIntervalForResource: TimeInterval - - private var state: State = .unopened - private let subject = PassthroughSubject() - - private let subjectQueue: DispatchQueue - - public convenience init(url: URL) { - self.init(url: url, publisherQueue: DispatchQueue.global()) - } + /// Publishes messages received from WebSocket. Finishes when the + /// WebSocket connection closes. + public var messagesPublisher: () -> AnyPublisher public init( - url: URL, - timeoutIntervalForRequest: TimeInterval = 60, // 60 seconds - timeoutIntervalForResource: TimeInterval = 604_800, // 7 days - publisherQueue: DispatchQueue = DispatchQueue.global() - ) { - self.url = url - self.timeoutIntervalForRequest = timeoutIntervalForRequest - self.timeoutIntervalForResource = timeoutIntervalForResource - subjectQueue = DispatchQueue( - label: "app.shareup.websocket.subjectqueue", - qos: .default, - autoreleaseFrequency: .workItem, - target: publisherQueue - ) - } - - deinit { - close() - } - - public func connect() { - sync { - os_log( - "connect: oldstate=%{public}@", - log: .webSocket, - type: .debug, - state.debugDescription - ) - - switch state { - case .closed, .unopened: - let delegate = WebSocketDelegate( - onOpen: onOpen, - onClose: onClose, - onCompletion: onCompletion - ) - - let config = URLSessionConfiguration.default - config.timeoutIntervalForRequest = timeoutIntervalForRequest - config.timeoutIntervalForResource = timeoutIntervalForResource - - let session = URLSession( - configuration: config, - delegate: delegate, - delegateQueue: nil - ) - - let task = session.webSocketTask(with: url) - task.maximumMessageSize = maximumMessageSize - state = .connecting(session, task, delegate) - task.resume() - receiveFromWebSocket() - - default: - break - } - } - } - - public func receive(subscriber: S) - where S.Input == Result, S.Failure == Swift.Error - { - subject.receive(subscriber: subscriber) - } - - private func receiveFromWebSocket() { - let task: URLSessionWebSocketTask? = sync { - let webSocketTask = self.state.webSocketSessionAndTask?.1 - guard let task = webSocketTask, case .running = task.state else { return nil } - return task + onOpen: @escaping (@escaping WebSocketOnOpen) async -> Void = { _ in }, + onClose: @escaping (@escaping WebSocketOnClose) async -> Void = { _ in }, + open: @escaping (TimeInterval?) async throws -> Void = { _ in }, + close: @escaping (WebSocketCloseCode) async throws -> Void = { _ in }, + send: @escaping (WebSocketMessage) async throws -> Void = { _ in }, + messagesPublisher: @escaping () -> AnyPublisher = { + Empty(completeImmediately: false).eraseToAnyPublisher() } - - task?.receive - { [weak self, weak task] (result: Result) in - guard let self = self else { return } - - let _result = result.map { WebSocketMessage($0) } - - guard task?.state == .running - else { - os_log( - "receive message in incorrect task state: message=%s taskstate=%{public}@", - log: .webSocket, - type: .debug, - _result.debugDescription, - "\(task?.state.rawValue ?? -1)" - ) - return - } - - os_log("receive: %s", log: .webSocket, type: .debug, _result.debugDescription) - self.subjectQueue.async { [weak self] in self?.subject.send(_result) } - self.receiveFromWebSocket() - } - } - - public func send( - _ string: String, - completionHandler: @escaping (Error?) -> Void = { _ in } - ) { - os_log("send: %s", log: .webSocket, type: .debug, string) - send(.string(string), completionHandler: completionHandler) - } - - public func send(_ data: Data, completionHandler: @escaping (Error?) -> Void = { _ in }) { - os_log("send: %lld bytes", log: .webSocket, type: .debug, data.count) - send(.data(data), completionHandler: completionHandler) - } - - private func send( - _ message: URLSessionWebSocketTask.Message, - completionHandler: @escaping (Error?) -> Void ) { - let task: URLSessionWebSocketTask? = sync { - guard case let .open(_, task, _) = state, task.state == .running - else { - os_log( - "send message in incorrect task state: message=%s taskstate=%{public}@", - log: .webSocket, - type: .debug, - message.debugDescription, - "\(self.state.webSocketSessionAndTask?.1.state.rawValue ?? -1)" - ) - completionHandler(WebSocketError.notOpen) - return nil - } - return task - } - - task?.send(message, completionHandler: completionHandler) - } - - public func close(_ closeCode: WebSocketCloseCode) { - let task: URLSessionWebSocketTask? = sync { - os_log( - "close: oldstate=%{public}@ code=%lld", - log: .webSocket, - type: .debug, - state.debugDescription, - closeCode.rawValue - ) - - guard let (_, task) = state.webSocketSessionAndTask, task.state == .running - else { return nil } - state = .closing - return task - } - - let code = URLSessionWebSocketTask.CloseCode(closeCode) ?? .invalid - task?.cancel(with: code, reason: nil) + self.onOpen = onOpen + self.onClose = onClose + self.open = open + self.close = close + self.send = send + self.messagesPublisher = messagesPublisher } } -private typealias OnOpenHandler = (URLSession, URLSessionWebSocketTask, String?) -> Void -private typealias OnCloseHandler = ( - URLSession, - URLSessionWebSocketTask, - URLSessionWebSocketTask.CloseCode, - Data? -) -> Void -private typealias OnCompletionHandler = (URLSession, URLSessionTask, Error?) -> Void - -private let normalCloseCodes: [URLSessionWebSocketTask.CloseCode] = [.goingAway, .normalClosure] - -// MARK: onOpen and onClose - -private extension WebSocket { - var onOpen: OnOpenHandler { - { [weak self] webSocketSession, webSocketTask, _ in - guard let self = self else { return } - - self.sync { - os_log( - "onOpen: oldstate=%{public}@", - log: .webSocket, - type: .debug, - self.state.debugDescription - ) - - guard case let .connecting(session, task, delegate) = self.state else { - os_log( - "receive onOpen callback in incorrect state: oldstate=%{public}@", - log: .webSocket, - type: .error, - self.state.debugDescription - ) - self.state = .open( - webSocketSession, - webSocketTask, - webSocketSession.delegate as! WebSocketDelegate - ) - return - } - - assert(session === webSocketSession) - assert(task === webSocketTask) - - self.state = .open(webSocketSession, webSocketTask, delegate) - } - - self.subjectQueue.async { [weak self] in self?.subject.send(.success(.open)) } - } +public extension WebSocket { + /// Calls `WebSocket.open(nil)`. + func open() async throws { + try await open(nil) } - var onClose: OnCloseHandler { - { [weak self] _, _, closeCode, reason in - guard let self = self else { return } - - self.sync { - os_log( - "onClose: oldstate=%{public}@ code=%lld", - log: .webSocket, - type: .debug, - self.state.debugDescription, - closeCode.rawValue - ) + /// Calls `WebSocket.close(closeCode: .goingAway)`. + func close() async throws { + try await close(.goingAway) + } - if case .closed = self.state { return } - self.state = .closed(WebSocketError.closed(closeCode, reason)) + /// The WebSocket's received messages as an asynchronous stream. + var messages: AsyncStream { + var cancellable: AnyCancellable? - self.subjectQueue.async { [weak self] in - if normalCloseCodes.contains(closeCode) { - self?.subject.send(completion: .finished) - } else { - self?.subject.send( - completion: .failure(WebSocketError.closed(closeCode, reason)) - ) - } + return AsyncStream { cont in + func finish() { + if cancellable != nil { + cont.finish() + cancellable = nil } } - } - } - - var onCompletion: OnCompletionHandler { - { [weak self] webSocketSession, _, error in - defer { webSocketSession.invalidateAndCancel() } - guard let self = self else { return } - os_log("onCompletion", log: .webSocket, type: .debug) - - // "The only errors your delegate receives through the error parameter - // are client-side errors, such as being unable to resolve the hostname - // or connect to the host." - // - // https://developer.apple.com/documentation/foundation/urlsessiontaskdelegate/1411610-urlsession - // - // When receiving these errors, `onClose` is not called because the connection - // was never actually opened. - guard let error = error else { return } - self.sync { - os_log( - "onCompletion: oldstate=%{public}@ error=%@", - log: .webSocket, - type: .debug, - self.state.debugDescription, - error.localizedDescription + let _cancellable = self.messagesPublisher() + .handleEvents(receiveCancel: { finish() }) + .sink( + receiveCompletion: { _ in finish() }, + receiveValue: { cont.yield($0) } ) - if case .closed = self.state { return } - self.state = .closed(.notOpen) - - self.subjectQueue.async { [weak self] in - self?.subject.send(completion: .failure(error)) - } - } + cancellable = _cancellable } } } -// MARK: URLSessionWebSocketDelegate - -private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { - private let onOpen: OnOpenHandler - private let onClose: OnCloseHandler - private let onCompletion: OnCompletionHandler - - init(onOpen: @escaping OnOpenHandler, - onClose: @escaping OnCloseHandler, - onCompletion: @escaping OnCompletionHandler) - { - self.onOpen = onOpen - self.onClose = onClose - self.onCompletion = onCompletion - super.init() - } - - func urlSession(_ webSocketSession: URLSession, - webSocketTask: URLSessionWebSocketTask, - didOpenWithProtocol protocol: String?) - { - onOpen(webSocketSession, webSocketTask, `protocol`) - } - - func urlSession(_ session: URLSession, - webSocketTask: URLSessionWebSocketTask, - didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, - reason: Data?) - { - onClose(session, webSocketTask, closeCode, reason) - } - - func urlSession(_ session: URLSession, - task: URLSessionTask, - didCompleteWithError error: Error?) - { - onCompletion(session, task, error) +public extension WebSocket { + /// System WebSocket implementation powered by the Network Framework. + static func system( + url: URL, + options: WebSocketOptions = .init(), + onOpen: @escaping WebSocketOnOpen = {}, + onClose: @escaping WebSocketOnClose = { _ in } + ) async throws -> Self { + let ws = try await SystemWebSocket( + url: url, + options: options, + onOpen: onOpen, + onClose: onClose + ) + return try await .system(ws) + } + + // This is only intended for use in tests. + internal static func system(_ ws: SystemWebSocket) async throws -> Self { + Self( + onOpen: { onOpen in await ws.onOpen(onOpen) }, + onClose: { onClose in await ws.onClose(onClose) }, + open: { timeout in try await ws.open(timeout: timeout) }, + close: { code in try await ws.close(code) }, + send: { message in try await ws.send(message) }, + messagesPublisher: { ws.eraseToAnyPublisher() } + ) } } diff --git a/Sources/WebSocket/WebSocketCloseCode.swift b/Sources/WebSocket/WebSocketCloseCode.swift new file mode 100644 index 0000000..c33552a --- /dev/null +++ b/Sources/WebSocket/WebSocketCloseCode.swift @@ -0,0 +1,79 @@ +import Foundation +import Network + +/// A code indicating why a WebSocket connection closed. +/// +/// Mirrors [URLSessionWebSocketTask](https://developer.apple.com/documentation/foundation/urlsessionwebsockettask/closecode). +public enum WebSocketCloseCode: Int, CaseIterable { + /// A code that indicates the connection is still open. + case invalid = 0 + + /// A code that indicates normal connection closure. + case normalClosure = 1000 + + /// A code that indicates an endpoint is going away. + case goingAway = 1001 + + /// A code that indicates an endpoint terminated the connection due to a protocol error. + case protocolError = 1002 + + /// A code that indicates an endpoint terminated the connection after receiving a type of data it can’t accept. + case unsupportedData = 1003 + + /// A reserved code that indicates an endpoint expected a status code and didn’t receive one. + case noStatusReceived = 1005 + + /// A reserved code that indicates the connection closed without a close control frame. + case abnormalClosure = 1006 + + /// A code that indicates the server terminated the connection because it received data inconsistent with the message’s type. + case invalidFramePayloadData = 1007 + + /// A code that indicates an endpoint terminated the connection because it received a message that violates its policy. + case policyViolation = 1008 + + /// A code that indicates an endpoint is terminating the connection because it received a message too big for it to process. + case messageTooBig = 1009 + + /// A code that indicates the client terminated the connection because the server didn’t negotiate a required extension. + case mandatoryExtensionMissing = 1010 + + /// A code that indicates the server terminated the connection because it encountered an unexpected condition. + case internalServerError = 1011 + + /// A reserved code that indicates the connection closed due to the failure to perform a TLS handshake. + case tlsHandshakeFailure = 1015 +} + +extension WebSocketCloseCode { + var error: NWError? { + switch self { + case .invalid: + return nil + case .normalClosure: + return nil + case .goingAway: + return nil + case .protocolError: + return .posix(.EPROTO) + case .unsupportedData: + return .posix(.EBADMSG) + case .noStatusReceived: + return nil + case .abnormalClosure: + return nil + case .invalidFramePayloadData: + return nil + case .policyViolation: + return nil + case .messageTooBig: + return .posix(.EMSGSIZE) + case .mandatoryExtensionMissing: + return nil + case .internalServerError: + return nil + case .tlsHandshakeFailure: + return .tls(errSSLHandshakeFail) + } + } +} diff --git a/Sources/WebSocket/WebSocketCloseResult.swift b/Sources/WebSocket/WebSocketCloseResult.swift new file mode 100644 index 0000000..048421d --- /dev/null +++ b/Sources/WebSocket/WebSocketCloseResult.swift @@ -0,0 +1,7 @@ +import Foundation + +public typealias WebSocketCloseResult = Result<(code: WebSocketCloseCode, reason: Data?), Error> + +internal let normalClosure: WebSocketCloseResult = .success((.normalClosure, nil)) +internal let abnormalClosure: WebSocketCloseResult = .success((.abnormalClosure, nil)) +internal let closureWithError: (Error) -> WebSocketCloseResult = { e in .failure(e) } diff --git a/Sources/WebSocket/WebSocketError.swift b/Sources/WebSocket/WebSocketError.swift index 7000884..a7a7cc0 100644 --- a/Sources/WebSocket/WebSocketError.swift +++ b/Sources/WebSocket/WebSocketError.swift @@ -1,8 +1,19 @@ import Foundation +import Network -public enum WebSocketError: Error { +public enum WebSocketError: Error, Equatable { case invalidURL(URL) case invalidURLComponents(URLComponents) - case notOpen - case closed(URLSessionWebSocketTask.CloseCode, Data?) + case openAfterConnectionClosed + case sendMessageWhileConnecting + case receiveMessageWhenNotOpen + case receiveUnknownMessageType + case connectionError(NWError) +} + +extension Optional where Wrapped == WebSocketError { + var debugDescription: String { + guard case let .some(error) = self else { return "" } + return String(reflecting: error) + } } diff --git a/Sources/WebSocket/WebSocketMessage+URLSessionWebSocketTask.swift b/Sources/WebSocket/WebSocketMessage+URLSessionWebSocketTask.swift deleted file mode 100644 index 7a96fad..0000000 --- a/Sources/WebSocket/WebSocketMessage+URLSessionWebSocketTask.swift +++ /dev/null @@ -1,27 +0,0 @@ -import Foundation -import WebSocketProtocol - -extension WebSocketMessage { - init(_ message: URLSessionWebSocketTask.Message) { - switch message { - case let .data(data): - self = .binary(data) - case let .string(string): - self = .text(string) - @unknown default: - assertionFailure("Unknown WebSocket Message type") - self = .text("") - } - } -} - -extension Result: CustomDebugStringConvertible where Success == WebSocketMessage { - public var debugDescription: String { - switch self { - case let .success(message): - return message.debugDescription - case let .failure(error): - return error.localizedDescription - } - } -} diff --git a/Sources/WebSocket/WebSocketMessage.swift b/Sources/WebSocket/WebSocketMessage.swift new file mode 100644 index 0000000..0cca3e7 --- /dev/null +++ b/Sources/WebSocket/WebSocketMessage.swift @@ -0,0 +1,20 @@ +import Foundation +import Network + +/// An enumeration of the types of messages that can be sent or received. +public enum WebSocketMessage: CustomStringConvertible, CustomDebugStringConvertible, Hashable { + /// A WebSocket message that contains a block of data. + case data(Data) + + /// A WebSocket message that contains a UTF-8 formatted string. + case text(String) + + public var description: String { + switch self { + case let .data(data): return "\(data.count) bytes" + case let .text(text): return text + } + } + + public var debugDescription: String { description } +} diff --git a/Sources/WebSocket/WebSocketOptions.swift b/Sources/WebSocket/WebSocketOptions.swift new file mode 100644 index 0000000..d88d0e2 --- /dev/null +++ b/Sources/WebSocket/WebSocketOptions.swift @@ -0,0 +1,17 @@ +import Foundation + +public struct WebSocketOptions: Hashable { + public var maximumMessageSize: Int + public var timeoutIntervalForRequest: TimeInterval + public var timeoutIntervalForResource: TimeInterval + + public init( + maximumMessageSize: Int = 1024 * 1024, // 1 MiB + timeoutIntervalForRequest: TimeInterval = 60, // 60 seconds + timeoutIntervalForResource: TimeInterval = 604_800 // 7 days + ) { + self.maximumMessageSize = maximumMessageSize + self.timeoutIntervalForRequest = timeoutIntervalForRequest + self.timeoutIntervalForResource = timeoutIntervalForResource + } +} diff --git a/Tests/WebSocketTests/Server/WebSocketServer.swift b/Tests/WebSocketTests/Server/WebSocketServer.swift index 5c4c44c..735eb1e 100644 --- a/Tests/WebSocketTests/Server/WebSocketServer.swift +++ b/Tests/WebSocketTests/Server/WebSocketServer.swift @@ -1,223 +1,227 @@ +import Combine import Foundation -import NIO -import NIOHTTP1 -import NIOWebSocket - -enum ReplyType { - case echo - case reply(() -> String?) - case matchReply((String) -> String?) +import Network +import WebSocket + +enum WebSocketServerError: Error { + case couldNotCreatePort(UInt16) +} + +enum WebSocketServerOutput: Hashable { + case die + case message(WebSocketMessage) } +private typealias E = WebSocketServerError + final class WebSocketServer { let port: UInt16 + let maximumMessageSize: Int + + // Publisher provided by consumers of `WebSocketServer` to provide the output + // `WebSocketServer` should send to its clients. + private let outputPublisher: AnyPublisher + private var outputPublisherSubscription: AnyCancellable? + + // Publisher the repeats everything sent to it by clients. + private let inputSubject = PassthroughSubject() + + private var listener: NWListener + private var connections: [NWConnection] = [] + + private let queue = DispatchQueue( + label: "app.shareup.websocketserverqueue", + qos: .default, + autoreleaseFrequency: .workItem, + target: .global() + ) + + init( + port: UInt16, + outputPublisher: P, + usesTLS: Bool = false, + maximumMessageSize: Int = 1024 * 1024 + ) throws where P.Output == WebSocketServerOutput, P.Failure == Error { + self.port = port + self.outputPublisher = outputPublisher.eraseToAnyPublisher() + self.maximumMessageSize = maximumMessageSize - private let replyType: ReplyType - private let eventLoopGroup: EventLoopGroup + let parameters = NWParameters(tls: usesTLS ? .init() : nil) + parameters.allowLocalEndpointReuse = true + parameters.includePeerToPeer = true + parameters.acceptLocalOnly = true - private var serverChannel: Channel? + let options = NWProtocolWebSocket.Options() + options.autoReplyPing = true + options.maximumMessageSize = maximumMessageSize - init(port: UInt16, replyProvider: ReplyType) { - self.port = port - replyType = replyProvider - eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - } + parameters.defaultProtocolStack.applicationProtocols.insert(options, at: 0) - func listen() { - do { - var addr = sockaddr_in() - addr.sin_port = in_port_t(port).bigEndian - let address = SocketAddress(addr, host: "0.0.0.0") + guard let port = NWEndpoint.Port(rawValue: port) + else { throw E.couldNotCreatePort(port) } - let bootstrap = makeBootstrap() - serverChannel = try bootstrap.bind(to: address).wait() + listener = try NWListener(using: parameters, on: port) - guard let localAddress = serverChannel?.localAddress else { - throw NIO.ChannelError.unknownLocalAddress + start() + } + + func forceClose() { + queue.sync { + connections.forEach { connection in + connection.forceCancel() } - print("WebSocketServer running on \(localAddress)") - } catch let error as NIO.IOError { - print("Failed to start server: \(error.errnoCode) '\(error.localizedDescription)'") - } catch { - print("Failed to start server: \(String(describing: error))") + connections.removeAll() + listener.cancel() } } - func close() { - do { try serverChannel?.close().wait() } - catch { print("Failed to wait on server: \(error)") } + var inputPublisher: AnyPublisher { + inputSubject.eraseToAnyPublisher() } +} - private func shouldUpgrade(channel _: Channel, - head: HTTPRequestHead) -> EventLoopFuture - { - let headers = head.uri.starts(with: "/socket") ? HTTPHeaders() : nil - return eventLoopGroup.next().makeSucceededFuture(headers) - } +private extension WebSocketServer { + func start() { + listener.newConnectionHandler = onNewConnection - private func upgradePipelineHandler(channel: Channel, head: HTTPRequestHead) -> NIO - .EventLoopFuture - { - head.uri.starts(with: "/socket") ? - channel.pipeline.addHandler(WebSocketHandler(replyProvider: replyProvider)) : channel - .closeFuture - } + listener.stateUpdateHandler = { [weak self] state in + guard let self = self else { return } + switch state { + case .failed: + self.close() - private var replyProvider: (String) -> String? { - { [weak self] (input: String) -> String? in - guard let self = self else { return nil } - switch self.replyType { - case .echo: - return input - case let .reply(iterator): - return iterator() - case let .matchReply(matcher): - return matcher(input) + default: + break } } - } - private func makeBootstrap() -> ServerBootstrap { - let reuseAddrOpt = ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR) - return ServerBootstrap(group: eventLoopGroup) - .serverChannelOption(ChannelOptions.backlog, value: 256) - .serverChannelOption(reuseAddrOpt, value: 1) - .childChannelInitializer { channel in - let connectionUpgrader = NIOWebSocketServerUpgrader( - shouldUpgrade: self.shouldUpgrade, - upgradePipelineHandler: self.upgradePipelineHandler - ) - - let config: NIOHTTPServerUpgradeConfiguration = ( - upgraders: [connectionUpgrader], - completionHandler: { _ in } - ) - - return channel.pipeline.configureHTTPServerPipeline( - position: .first, - withPipeliningAssistance: true, - withServerUpgrade: config, - withErrorHandling: true - ) - } - .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) - .childChannelOption(reuseAddrOpt, value: 1) - .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) + listener.start(queue: queue) } -} -private class WebSocketHandler: ChannelInboundHandler { - typealias InboundIn = WebSocketFrame - typealias OutboundOut = WebSocketFrame + func broadcastMessage(_ message: WebSocketMessage) { + let context: NWConnection.ContentContext + let content: Data - private let replyProvider: (String) -> String? - private var awaitingClose = false + switch message { + case let .data(data): + let metadata: NWProtocolWebSocket.Metadata = .init(opcode: .binary) + context = .init(identifier: String(message.hashValue), metadata: [metadata]) + content = data - init(replyProvider: @escaping (String) -> String?) { - self.replyProvider = replyProvider - } + case let .text(string): + let metadata: NWProtocolWebSocket.Metadata = .init(opcode: .text) + context = .init(identifier: String(message.hashValue), metadata: [metadata]) + content = Data(string.utf8) + } - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let frame = unwrapInboundIn(data) - - switch frame.opcode { - case .connectionClose: - onClose(context: context, frame: frame) - case .ping: - onPing(context: context, frame: frame) - case .text: - var data = frame.unmaskedData - let text = data.readString(length: data.readableBytes) ?? "" - onText(context: context, text: text) - case .binary: - let buffer = frame.unmaskedData - var data = Data(capacity: buffer.readableBytes) - buffer.withUnsafeReadableBytes { data.append(contentsOf: $0) } - onBinary(context: context, binary: data) - default: - onError(context: context) + connections.forEach { connection in + connection.send( + content: content, + contentContext: context, + isComplete: true, + completion: .contentProcessed { [weak self] error in + guard let _ = error else { return } + self?.closeConnection(connection) + } + ) } } - private func onBinary(context: ChannelHandlerContext, binary: Data) { - do { - // Obviously, this would need to be changed to actually handle data input - if let text = String(data: binary, encoding: .utf8) { - onText(context: context, text: text) - } else { - throw NIO.IOError(errnoCode: EBADMSG, reason: "Invalid message") + func close() { + connections.forEach { closeConnection($0) } + connections.removeAll() + listener.cancel() + } + + func closeConnection(_ connection: NWConnection) { + connection.send( + content: nil, + contentContext: .finalMessage, + isComplete: true, + completion: .contentProcessed { _ in + connection.cancel() } - } catch { - onError(context: context) - } + ) } - private func onText(context: ChannelHandlerContext, text: String) { - guard let reply = replyProvider(text) else { return } - - var replyBuffer = context.channel.allocator.buffer(capacity: reply.utf8.count) - replyBuffer.writeString(reply) - - let frame = WebSocketFrame(fin: true, opcode: .text, data: replyBuffer) - - _ = context.channel.writeAndFlush(frame) + func cancelConnection(_ connection: NWConnection) { + connection.forceCancel() + connections.removeAll(where: { $0 === connection }) } - private func onPing(context: ChannelHandlerContext, frame: WebSocketFrame) { - var frameData = frame.data + var onNewConnection: (NWConnection) -> Void { + { [weak self] (newConnection: NWConnection) in + guard let self = self else { return } - if let maskingKey = frame.maskKey { - frameData.webSocketUnmask(maskingKey) - } + self.connections.append(newConnection) - let pong = WebSocketFrame(fin: true, opcode: .pong, data: frameData) - context.write(wrapOutboundOut(pong), promise: nil) - } + func receive() { + newConnection.receiveMessage { [weak self] data, context, _, error in + guard let self = self else { return } + guard error == nil else { return self.closeConnection(newConnection) } - private func onClose(context: ChannelHandlerContext, frame: WebSocketFrame) { - if awaitingClose { - // We sent the initial close and were waiting for the client's response - context.close(promise: nil) - } else { - // The close came from the client. - var data = frame.unmaskedData - let closeDataCode = data.readSlice(length: 2) ?? context.channel.allocator - .buffer(capacity: 0) - let closeFrame = WebSocketFrame( - fin: true, - opcode: .connectionClose, - data: closeDataCode - ) - _ = context.write(wrapOutboundOut(closeFrame)).map { () in - context.close(promise: nil) - } - } - } + guard let data = data, + let context = context, + let _metadata = context.protocolMetadata.first, + let metadata = _metadata as? NWProtocolWebSocket.Metadata + else { return } - private func onError(context: ChannelHandlerContext) { - var data = context.channel.allocator.buffer(capacity: 2) - data.write(webSocketErrorCode: .protocolError) - let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data) - context.write(wrapOutboundOut(frame)).whenComplete { (_: Result) in - context.close(mode: .output, promise: nil) - } - awaitingClose = true - } + switch metadata.opcode { + case .binary: + self.inputSubject.send(.data(data)) - func channelReadComplete(context: ChannelHandlerContext) { - context.flush() - } + case .text: + if let text = String(data: data, encoding: .utf8) { + self.inputSubject.send(.text(text)) + } - func channelActive(context: ChannelHandlerContext) { - print("Channel active: \(String(describing: context.channel.remoteAddress))") - } + default: + break + } - func channelInactive(context: ChannelHandlerContext) { - print("Channel closed: \(String(describing: context.localAddress))") - } + receive() + } + } + receive() + + newConnection.stateUpdateHandler = { [weak self] state in + guard let self = self else { return } + + switch state { + case .ready: + guard self.outputPublisherSubscription == nil else { break } + self.outputPublisherSubscription = self.outputPublisher + .receive(on: self.queue) + .sink( + receiveCompletion: { [weak self] completion in + guard let self = self else { return } + guard case .failure = completion else { + self.cancelConnection(newConnection) + return + } + self.close() + }, + receiveValue: { [weak self] (output: WebSocketServerOutput) in + guard let self = self else { return } + switch output { + case .die: + self.cancelConnection(newConnection) + + case let .message(message): + self.broadcastMessage(message) + } + } + ) + + case .failed: + self.cancelConnection(newConnection) + + default: + break + } + } - func errorCaught(context: ChannelHandlerContext, error: Error) { - print("Error: \(error)") - context.close(promise: nil) + newConnection.start(queue: self.queue) + } } } diff --git a/Tests/WebSocketTests/SystemWebSocketTests.swift b/Tests/WebSocketTests/SystemWebSocketTests.swift new file mode 100644 index 0000000..44ff10e --- /dev/null +++ b/Tests/WebSocketTests/SystemWebSocketTests.swift @@ -0,0 +1,361 @@ +import Combine +@testable import WebSocket +import XCTest + +private var ports = (50000 ... 52000).map { UInt16($0) } + +// NOTE: If `WebSocketTests` is not marked as `@MainActor`, calls to +// `wait(for:timeout:)` prevent other asyncronous events from running. +// Using `await waitForExpectations(timeout:handler:)` works properly +// because it's already marked as `@MainActor`. + +@MainActor +class SystemWebSocketTests: XCTestCase { + var subject: PassthroughSubject! + + @MainActor + override func setUp() async throws { + try await super.setUp() + subject = .init() + } + + func testCanConnectToAndDisconnectFromServer() async throws { + let openEx = expectation(description: "Should have opened") + let closeEx = expectation(description: "Should have closed") + let (server, client) = await makeServerAndClient( + onOpen: { openEx.fulfill() }, + onClose: { result in + switch result { + case let .success(close): + XCTAssertEqual(.normalClosure, close.code) + XCTAssertNil(close.reason) + closeEx.fulfill() + + case let .failure(error): + XCTFail("Should not have received error: \(error)") + } + } + ) + defer { server.forceClose() } + + wait(for: [openEx], timeout: 2) + + let isOpen = await client.isOpen + XCTAssertTrue(isOpen) + + try await client.close() + wait(for: [closeEx], timeout: 2) + } + + func testErrorWhenServerIsUnreachable() async throws { + let ex = expectation(description: "Should have errored") + let (server, client) = await makeOfflineServerAndClient( + onOpen: { XCTFail("Should not have opened") }, + onClose: { result in + switch result { + case let .success(close): + XCTFail("Should not have closed successfully: \(String(reflecting: close))") + + case let .failure(error): + guard let webSocketError = error as? WebSocketError, + case let .connectionError(nwerror) = webSocketError, + case let .posix(posix) = nwerror + else { return XCTFail("Closed with incorrect error: \(error)") } + XCTAssertEqual(.ECONNREFUSED, posix) + ex.fulfill() + } + } + ) + defer { server.forceClose() } + + waitForExpectations(timeout: 2) + + let isClosed = await client.isClosed + XCTAssertTrue(isClosed) + } + + func testErrorWhenRemoteCloses() async throws { + let errorEx = expectation(description: "Should have closed") + let (server, client) = await makeServerAndClient( + onClose: { result in + switch result { + case let .success(close): + XCTFail("Should not have closed successfully: \(String(reflecting: close))") + + case let .failure(error): + guard let err = error as? WebSocketError, + case .receiveUnknownMessageType = err + else { return XCTFail("Should have received unknown message error") } + errorEx.fulfill() + } + } + ) + defer { server.forceClose() } + + try await client.open() + + subject.send(.die) + wait(for: [errorEx], timeout: 2) + } + + func testWebSocketCannotBeOpenedTwice() async throws { + var closeCount = 0 + + let firstCloseEx = expectation(description: "Should have closed once") + let secondCloseEx = expectation(description: "Should not have closed more than once") + secondCloseEx.isInverted = true + + let (server, client) = await makeServerAndClient( + onClose: { _ in + closeCount += 1 + if closeCount == 1 { + firstCloseEx.fulfill() + } else { + secondCloseEx.fulfill() + } + } + ) + defer { server.forceClose() } + + try await client.open() + + try await client.close() + wait(for: [firstCloseEx], timeout: 2) + + do { + try await client.open() + XCTFail("Should not have successfully reopened") + } catch { + guard let wserror = error as? WebSocketError, + case .openAfterConnectionClosed = wserror + else { return XCTFail("Received wrong error: \(error)") } + } + + wait(for: [secondCloseEx], timeout: 0.1) + } + + func testPushAndReceiveText() async throws { + let (server, client) = await makeServerAndClient() + defer { server.forceClose() } + + let sentEx = expectation(description: "Server should have received message") + let sentSub = server.inputPublisher + .sink(receiveValue: { message in + guard case let .text(text) = message + else { return XCTFail("Should have received text") } + XCTAssertEqual("hello", text) + sentEx.fulfill() + }) + defer { sentSub.cancel() } + + try await client.open() + + let receivedEx = expectation(description: "Should have received message") + let receivedSub = client.sink { message in + defer { receivedEx.fulfill() } + guard case let .text(text) = message + else { return XCTFail("Should have received text") } + XCTAssertEqual("hi, to you too!", text) + } + defer { receivedSub.cancel() } + + try await client.send(.text("hello")) + wait(for: [sentEx], timeout: 2) + subject.send(.message(.text("hi, to you too!"))) + wait(for: [receivedEx], timeout: 2) + } + + @available(iOS 15.0, macOS 12.0, *) + func testPushAndReceiveTextWithAsyncPublisher() async throws { + let (server, client) = await makeServerAndClient() + defer { server.forceClose() } + + try await client.open() + + try await client.send(.text("hello")) + subject.send(.message(.text("hi, to you too!"))) + + for await message in client.values { + guard case let .text(text) = message else { + XCTFail("Should have received text") + break + } + XCTAssertEqual("hi, to you too!", text) + break + } + } + + func testPushAndReceiveData() async throws { + let (server, client) = await makeServerAndClient() + defer { server.forceClose() } + + let sentEx = expectation(description: "Server should have received message") + let sentSub = server.inputPublisher + .sink(receiveValue: { message in + guard case let .data(data) = message + else { return XCTFail("Should have received data") } + XCTAssertEqual(Data("hello".utf8), data) + sentEx.fulfill() + }) + defer { sentSub.cancel() } + + try await client.open() + + let receivedEx = expectation(description: "Should have received message") + let receivedSub = client.sink { message in + defer { receivedEx.fulfill() } + guard case let .data(data) = message + else { return XCTFail("Should have received data") } + XCTAssertEqual(Data("hi, to you too!".utf8), data) + } + defer { receivedSub.cancel() } + + try await client.send(.data(Data("hello".utf8))) + wait(for: [sentEx], timeout: 2) + subject.send(.message(.data(Data("hi, to you too!".utf8)))) + wait(for: [receivedEx], timeout: 2) + } + + @available(iOS 15.0, macOS 12.0, *) + func testPushAndReceiveDataWithAsyncPublisher() async throws { + let (server, client) = await makeServerAndClient() + defer { server.forceClose() } + + try await client.open() + + try await client.send(.data(Data("hello bytes".utf8))) + subject.send(.message(.data(Data("howdy".utf8)))) + + for await message in client.values { + guard case let .data(data) = message else { + XCTFail("Should have received data") + break + } + XCTAssertEqual("howdy", String(data: data, encoding: .utf8)) + break + } + } + + func testWrappedSystemWebSocket() async throws { + let openEx = expectation(description: "Should have opened") + let closeEx = expectation(description: "Should have closed") + let (server, client) = await makeServerAndWrappedClient( + onOpen: { openEx.fulfill() }, + onClose: { result in + switch result { + case let .success((code, reason)): + XCTAssertEqual(.normalClosure, code) + XCTAssertNil(reason) + closeEx.fulfill() + case let .failure(error): + XCTFail("Should not have failed: \(error)") + } + } + ) + defer { server.forceClose() } + + var messagesToSend: [WebSocketMessage] = [ + .text("one"), + .data(Data("two".utf8)), + .text("three"), + ] + + var messagesToReceive: [WebSocketMessage] = [ + .text("one"), + .data(Data("two".utf8)), + .text("three"), + ] + + let sentSub = server.inputPublisher + .sink(receiveValue: { message in + let expected = messagesToSend.removeFirst() + XCTAssertEqual(expected, message) + }) + defer { sentSub.cancel() } + + // These two lines are redundant, but the goal + // is to test everything in `WebSocket`. + try await client.open() + wait(for: [openEx], timeout: 2) + + // These messages have to be sent after the `AsyncStream` is + // subscribed to below. So, we send them asynchronously. + let firstMessageToReceive = try XCTUnwrap(messagesToReceive.first) + let firstMessageToSend = try XCTUnwrap(messagesToSend.first) + Task.detached { + await self.subject.send(.message(firstMessageToReceive)) + try await client.send(firstMessageToSend) + } + + for await message in client.messages { + let expected = messagesToReceive.removeFirst() + XCTAssertEqual(expected, message) + + if let messageToSend = messagesToSend.first, + let messageToReceive = messagesToReceive.first + { + try await client.send(messageToSend) + subject.send(.message(messageToReceive)) + } else { + try await client.close() + } + } + + XCTAssertTrue(messagesToSend.isEmpty) + XCTAssertTrue(messagesToReceive.isEmpty) + + wait(for: [closeEx], timeout: 2) + } +} + +private let empty: Empty = Empty( + completeImmediately: false, + outputType: WebSocketServerOutput.self, + failureType: Error.self +) + +private extension SystemWebSocketTests { + func url(_ port: UInt16) -> URL { URL(string: "ws://0.0.0.0:\(port)/socket")! } + + func makeServerAndClient( + onOpen: @escaping () -> Void = {}, + onClose: @escaping (WebSocketCloseResult) -> Void = { _ in } + ) async -> (WebSocketServer, SystemWebSocket) { + let port = ports.removeFirst() + let server = try! WebSocketServer(port: port, outputPublisher: subject) + let client = try! await SystemWebSocket( + url: url(port), + onOpen: onOpen, + onClose: onClose + ) + return (server, client) + } + + func makeOfflineServerAndClient( + onOpen: @escaping () -> Void = {}, + onClose: @escaping (WebSocketCloseResult) -> Void = { _ in } + ) async -> (WebSocketServer, SystemWebSocket) { + let port = ports.removeFirst() + let server = try! WebSocketServer(port: 1, outputPublisher: empty) + let client = try! await SystemWebSocket( + url: url(port), + onOpen: onOpen, + onClose: onClose + ) + return (server, client) + } + + func makeServerAndWrappedClient( + onOpen: @escaping () -> Void = {}, + onClose: @escaping (WebSocketCloseResult) -> Void = { _ in } + ) async -> (WebSocketServer, WebSocket) { + let port = ports.removeFirst() + let server = try! WebSocketServer(port: port, outputPublisher: subject) + let client = try! await SystemWebSocket( + url: url(port), + onOpen: onOpen, + onClose: onClose + ) + return (server, try! await .system(client)) + } +} diff --git a/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift b/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift index 66df139..7ea92f7 100644 --- a/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift +++ b/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift @@ -1,5 +1,4 @@ @testable import WebSocket -import WebSocketProtocol import XCTest class URLSessionWebSocketTaskCloseCodeTests: XCTestCase { @@ -7,13 +6,15 @@ class URLSessionWebSocketTaskCloseCodeTests: XCTestCase { let urlSessionCloseCodes: [URLSessionWebSocketTask.CloseCode] = [ .invalid, .normalClosure, .goingAway, .protocolError, .unsupportedData, .noStatusReceived, .abnormalClosure, .invalidFramePayloadData, .policyViolation, - .messageTooBig, .mandatoryExtensionMissing, .internalServerError, .tlsHandshakeFailure, + .messageTooBig, .mandatoryExtensionMissing, .internalServerError, + .tlsHandshakeFailure, ] let closeCodes: [WebSocketCloseCode] = [ .invalid, .normalClosure, .goingAway, .protocolError, .unsupportedData, .noStatusReceived, .abnormalClosure, .invalidFramePayloadData, .policyViolation, - .messageTooBig, .mandatoryExtensionMissing, .internalServerError, .tlsHandshakeFailure, + .messageTooBig, .mandatoryExtensionMissing, .internalServerError, + .tlsHandshakeFailure, ] zip(urlSessionCloseCodes, closeCodes).forEach { urlSessionCloseCode, closeCode in diff --git a/Tests/WebSocketTests/WebSocketTests.swift b/Tests/WebSocketTests/WebSocketTests.swift deleted file mode 100644 index 3853d3d..0000000 --- a/Tests/WebSocketTests/WebSocketTests.swift +++ /dev/null @@ -1,285 +0,0 @@ -import Combine -@testable import WebSocket -import WebSocketProtocol -import XCTest - -private var ports = (50000 ... 52000).map { UInt16($0) } - -class WebSocketTests: XCTestCase { - func url(_ port: UInt16) -> URL { URL(string: "ws://0.0.0.0:\(port)/socket")! } - - func testCanConnectToAndDisconnectFromServer() throws { - try withServer { _, client in - let sub = client.sink( - receiveCompletion: expectFinished(), - receiveValue: expectValueAndThen(WebSocketMessage.open, client.close()) - ) - defer { sub.cancel() } - - client.connect() - waitForExpectations(timeout: 2) - - XCTAssertFalse(client.isOpen) - XCTAssertTrue(client.isClosed) - } - } - - func testCompleteWhenServerIsUnreachable() throws { - try withServer { server, client in - server.close() - - let sub = client.sink( - receiveCompletion: expectFailure(), - receiveValue: { result in - switch result { - case .failure: - // It's possible to receive or not receive an error. - // Clients need to be resilient in the face of this reality. - break - case let .success(message): - XCTFail("Should not have received message: \(message)") - } - } - ) - defer { sub.cancel() } - - client.connect() - waitForExpectations(timeout: 0.2) - - XCTAssertTrue(client.isClosed) - } - } - - func testCompleteWhenRemoteCloses() throws { - try withServer { _, client in - var invalidUTF8Bytes = [0x192, 0x193] as [UInt16] - let bytes = withUnsafeBytes(of: &invalidUTF8Bytes) { Array($0) } - let data = Data(bytes: bytes, count: bytes.count) - - let openEx = self.expectation(description: "Should have opened") - let errorEx = self.expectation(description: "Should have erred") - - let sub = client.sink( - receiveCompletion: expectFailure(), - receiveValue: { result in - switch result { - case .success(.open): - XCTAssertTrue(client.isOpen) - XCTAssertFalse(client.isClosed) - client.send(data) - openEx.fulfill() - case let .failure(error as NSError): - XCTAssertEqual("NSPOSIXErrorDomain", error.domain) - XCTAssertEqual(57, error.code) - errorEx.fulfill() - default: - break - } - } - ) - defer { sub.cancel() } - - client.connect() - waitForExpectations(timeout: 2) - - XCTAssertFalse(client.isOpen) - XCTAssertTrue(client.isClosed) - } - } - - func testEchoPush() throws { - try withEchoServer { _, client in - let message = "hello" - let completion = self.expectNoError() - - let sub = client.sink( - receiveCompletion: expectFinished(), - receiveValue: expectValuesAndThen([ - .open: { client.send(message, completionHandler: completion) }, - .text(message): { client.close() }, - ]) - ) - defer { sub.cancel() } - - client.connect() - waitForExpectations(timeout: 2) - } - } - - func testEchoBinaryPush() throws { - try withEchoServer { _, client in - let message = "hello" - let binary = message.data(using: .utf8)! - let completion = self.expectNoError() - - let sub = client.sink( - receiveCompletion: expectFinished(), - receiveValue: expectValuesAndThen([ - .open: { client.send(binary, completionHandler: completion) }, - .text(message): { client.close() }, - ]) - ) - defer { sub.cancel() } - - client.connect() - waitForExpectations(timeout: 2) - } - } - - func testJoinLobbyAndEcho() throws { - let joinPush = "[1,1,\"room:lobby\",\"phx_join\",{}]" - let echoPush1 = "[1,2,\"room:lobby\",\"echo\",{\"echo\":\"one\"}]" - let echoPush2 = "[1,3,\"room:lobby\",\"echo\",{\"echo\":\"two\"}]" - - let joinReply = "[1,1,\"room:lobby\",\"phx_reply\",{\"response\":{},\"status\":\"ok\"}]" - let echoReply1 = - "[1,2,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"one\"},\"status\":\"ok\"}]" - let echoReply2 = - "[1,3,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"two\"},\"status\":\"ok\"}]" - - let joinCompletion = expectNoError() - let echo1Completion = expectNoError() - let echo2Completion = expectNoError() - - try withReplyServer([joinReply, echoReply1, echoReply2]) { _, client in - let sub = client.sink( - receiveCompletion: expectFinished(), - receiveValue: expectValuesAndThen([ - .open: { client.send(joinPush, completionHandler: joinCompletion) }, - .text(joinReply): { client.send(echoPush1, completionHandler: echo1Completion) - }, - .text(echoReply1): { client.send(echoPush2, completionHandler: echo2Completion) - }, - .text(echoReply2): { client.close() }, - ]) - ) - defer { sub.cancel() } - - client.connect() - waitForExpectations(timeout: 2) - } - } - - func testCanSendFromTwoThreadsSimultaneously() throws { - let queueCount = 8 - let queues = (0 ..< queueCount).map { DispatchQueue(label: "\($0)") } - - let messageCount = 100 - let sendMessages: (WebSocket) -> Void = { client in - (0 ..< messageCount).forEach { messageIndex in - (0 ..< queueCount).forEach { queueIndex in - queues[queueIndex].async { client.send("\(queueIndex)-\(messageIndex)") } - } - } - } - - let receiveMessageEx = expectation( - description: "Should have received \(queueCount * messageCount) messages" - ) - receiveMessageEx.expectedFulfillmentCount = queueCount * messageCount - - try withEchoServer { _, client in - let sub = client.sink( - receiveCompletion: { _ in }, - receiveValue: { message in - switch message { - case .success(.open): - sendMessages(client) - case .success(.text): - receiveMessageEx.fulfill() - default: - XCTFail() - } - } - ) - defer { sub.cancel() } - - client.connect() - waitForExpectations(timeout: 10) - client.close() - } - } -} - -private extension WebSocketTests { - func withServer(_ block: (WebSocketServer, WebSocket) throws -> Void) throws { - let port = ports.removeFirst() - let server = WebSocketServer(port: port, replyProvider: .reply { nil }) - let client = WebSocket(url: url(port)) - try withExtendedLifetime((server, client)) { server.listen(); try block(server, client) } - } - - func withEchoServer(_ block: (WebSocketServer, WebSocket) throws -> Void) throws { - let port = ports.removeFirst() - let server = WebSocketServer(port: port, replyProvider: .echo) - let client = WebSocket(url: url(port)) - try withExtendedLifetime((server, client)) { server.listen(); try block(server, client) } - } - - func withReplyServer( - _ replies: [String?], - _ block: (WebSocketServer, WebSocket) throws -> Void - ) throws { - let port = ports.removeFirst() - var replies = replies - let provider: () -> String? = { replies.removeFirst() } - let server = WebSocketServer(port: port, replyProvider: .reply(provider)) - let client = WebSocket(url: url(port)) - try withExtendedLifetime((server, client)) { server.listen(); try block(server, client) } - } -} - -private extension WebSocketTests { - func expectValueAndThen( - _ value: T, - _ block: @escaping @autoclosure () -> Void - ) -> (Result) -> Void { - expectValuesAndThen([value: block]) - } - - func expectValuesAndThen< - T: Hashable, - E: Error - >(_ values: [T: () -> Void]) -> (Result) -> Void { - var values = values - let expectation = self - .expectation(description: "Should have received \(String(describing: values))") - return { (result: Result) in - guard case let .success(value) = result else { - return XCTFail("Unexpected result: \(String(describing: result))") - } - - let block = values.removeValue(forKey: value) - XCTAssertNotNil(block) - block?() - - if values.isEmpty { - expectation.fulfill() - } - } - } - - func expectFinished() -> (Subscribers.Completion) -> Void { - let expectation = self.expectation(description: "Should have finished successfully") - return { completion in - guard case Subscribers.Completion.finished = completion else { return } - expectation.fulfill() - } - } - - func expectFailure() -> (Subscribers.Completion) -> Void where E: Error { - let expectation = self.expectation(description: "Should have failed") - return { completion in - guard case Subscribers.Completion.failure = completion else { return } - expectation.fulfill() - } - } - - func expectNoError() -> (Error?) -> Void { - let expectation = self.expectation(description: "Should not have had an error") - return { error in - XCTAssertNil(error) - expectation.fulfill() - } - } -}