From 9a749715d4e8a1dc347cb841b548ed39a61ab450 Mon Sep 17 00:00:00 2001 From: Tim Perry Date: Mon, 28 Aug 2023 14:40:45 +0100 Subject: [PATCH] Fix a bug where websocket subprotocols were not forwarded --- src/rules/websockets/websocket-handlers.ts | 8 +++++++- src/util/header-utils.ts | 2 +- test/integration/websockets.spec.ts | 19 ++++++++++++------- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/rules/websockets/websocket-handlers.ts b/src/rules/websockets/websocket-handlers.ts index a40d5c97f..8688914f6 100644 --- a/src/rules/websockets/websocket-handlers.ts +++ b/src/rules/websockets/websocket-handlers.ts @@ -25,6 +25,7 @@ import { } from '../../util/request-utils'; import { findRawHeader, + findRawHeaders, objectHeadersToRaw, pairFlatRawHeaders, rawHeadersToObjectPreservingCase @@ -329,7 +330,12 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi // header object internally. const headers = rawHeadersToObjectPreservingCase(rawHeaders); - const upstreamWebSocket = new WebSocket(wsUrl, { + // Subprotocols have to be handled explicitly. WS takes control of the headers itself, + // and checks the response, so we need to parse the client headers and use them manually: + const subprotocols = findRawHeaders(rawHeaders, 'sec-websocket-protocol') + .flatMap(([_k, value]) => value.split(',').map(p => p.trim())); + + const upstreamWebSocket = new WebSocket(wsUrl, subprotocols, { maxPayload: 0, agent, lookup: getDnsLookupFunction(this.lookupOptions), diff --git a/src/util/header-utils.ts b/src/util/header-utils.ts index 06d2baaf5..f0a422296 100644 --- a/src/util/header-utils.ts +++ b/src/util/header-utils.ts @@ -30,7 +30,7 @@ export const findRawHeader = (rawHeaders: RawHeaders, targetKey: string) => export const findRawHeaderIndex = (rawHeaders: RawHeaders, targetKey: string) => rawHeaders.findIndex(([key]) => key.toLowerCase() === targetKey); -const findRawHeaders = (rawHeaders: RawHeaders, targetKey: string) => +export const findRawHeaders = (rawHeaders: RawHeaders, targetKey: string) => rawHeaders.filter(([key]) => key.toLowerCase() === targetKey); /** diff --git a/test/integration/websockets.spec.ts b/test/integration/websockets.spec.ts index 2f9afff23..aa5dfc264 100644 --- a/test/integration/websockets.spec.ts +++ b/test/integration/websockets.spec.ts @@ -146,13 +146,17 @@ nodeOnly(() => { it("forwards the incoming requests's headers", async () => { mockServer.forAnyWebSocket().thenPassThrough(); - const ws = new WebSocket(`ws://localhost:${wsPort}`, { - agent: new HttpProxyAgent(`http://localhost:${mockServer.port}`), - headers: { - 'echo-headers': 'true', - 'Funky-HEADER-casing': 'Header-Value' + const ws = new WebSocket( + `ws://localhost:${wsPort}`, + ['subprotocol-a', 'subprotocol-b'], + { + agent: new HttpProxyAgent(`http://localhost:${mockServer.port}`), + headers: { + 'echo-headers': 'true', + 'Funky-HEADER-casing': 'Header-Value' + } } - }); + ); const response = await new Promise((resolve, reject) => { ws.on('message', resolve); @@ -172,7 +176,8 @@ nodeOnly(() => { [ 'Sec-WebSocket-Version', '13' ], [ 'Connection', 'Upgrade' ], [ 'Upgrade', 'websocket' ], - [ 'Sec-WebSocket-Extensions', 'permessage-deflate; client_max_window_bits' ] + [ 'Sec-WebSocket-Extensions', 'permessage-deflate; client_max_window_bits' ], + [ 'Sec-WebSocket-Protocol', 'subprotocol-a,subprotocol-b' ] ]); });