From 52d253f2cfdd554bf528c3cbb05ac4a0a106659f Mon Sep 17 00:00:00 2001 From: hippoz <10706925-hippoz@users.noreply.gitlab.com> Date: Tue, 8 Aug 2023 15:36:51 +0300 Subject: [PATCH] greatly refactor gateway --- src/gateway/index.ts | 634 +++++++++++++++--------------- src/types/gatewayclientstate.d.ts | 11 - src/types/ws.d.ts | 7 - 3 files changed, 313 insertions(+), 339 deletions(-) delete mode 100644 src/types/gatewayclientstate.d.ts delete mode 100644 src/types/ws.d.ts diff --git a/src/gateway/index.ts b/src/gateway/index.ts index 3284a5d..27a1d91 100644 --- a/src/gateway/index.ts +++ b/src/gateway/index.ts @@ -1,6 +1,6 @@ import { Server } from "node:http"; import { performance } from "node:perf_hooks"; -import { WebSocketServer, WebSocket } from "ws"; +import WebSocket, { WebSocketServer } from "ws"; import { decodeTokenOrNull, getPublicUserObject } from "../auth"; import { query } from "../database"; import { gatewayErrors } from "../errors"; @@ -15,16 +15,21 @@ const GATEWAY_PING_INTERVAL = 40000; const MAX_CLIENT_MESSAGES_PER_BATCH = 30; // TODO: how well does this work for weak connections? const MAX_GATEWAY_SESSIONS_PER_USER = 5; + // mapping between a dispatch id and a websocket client -const dispatchChannels = new Map>(); +const dispatchChannels = new Map>(); // mapping between a user id and the websocket sessions it has -const sessionsByUserId = new Map(); +const sessionsByUserId = new Map(); // mapping between a dispatch id and a temporary handler const dispatchTemporary = new Map void>>(); -export function handle(channels: string[], handler: (payload: GatewayPayload) => void): (() => any) { +// all clients +const gatewayClients = new Set(); + + +function handle(channels: string[], handler: (payload: GatewayPayload) => void): (() => any) { channels.forEach(c => { if (!dispatchTemporary.get(c)) { dispatchTemporary.set(c, new Set()); @@ -64,65 +69,17 @@ export function waitForEvent(channels: string[], timeout: number) { }); } -function clientSubscribe(ws: WebSocket, dispatchChannel: string) { - ws.state.dispatchChannels.add(dispatchChannel); - if (!dispatchChannels.get(dispatchChannel)) { - dispatchChannels.set(dispatchChannel, new Set()); - } - - dispatchChannels.get(dispatchChannel)?.add(ws); -} - -function clientUnsubscribe(ws: WebSocket, dispatchChannel: string) { - if (!ws.state) return; - - ws.state.dispatchChannels.delete(dispatchChannel); - - const set = dispatchChannels.get(dispatchChannel); - if (!set) return; - - set.delete(ws); - if (set.size < 1) { - dispatchChannels.delete(dispatchChannel); - } -} - -export function dispatchChannelSubscribe(target: string, dispatchChannel: string) { - const set = dispatchChannels.get(target); - if (!set) return; - - set.forEach(c => { - clientSubscribe(c, dispatchChannel); - }); -} - -function clientUnsubscribeAll(ws: WebSocket) { - if (!ws.state) return; - - ws.state.dispatchChannels.forEach(e => { - const set = dispatchChannels.get(e); - if (!set) return; - - set.delete(ws); - if (set && set.size < 1) { - dispatchChannels.delete(e); - } - }); - - ws.state.dispatchChannels = new Set(); -} - -export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSocket | null) => GatewayPayload)) { +export function dispatch(channel: string, message: GatewayPayload | ((ws: GatewayClient | null) => GatewayPayload)) { const members = dispatchChannels.get(channel); if (!members) return; members.forEach(e => { - if (e.state.ready) { + if (e.ready) { let data = message; if (typeof message === "function") { data = message(e); } - e.send(JSON.stringify(data)); + e.send(data); } }); @@ -139,14 +96,34 @@ export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSoc } } -function closeWithError(ws: WebSocket, { code, message }: { code: number, message: string }) { - return ws.close(code, message); +export function dispatchChannelSubscribe(target: string, dispatchChannel: string) { + const set = dispatchChannels.get(target); + set?.forEach(c => c.subscribe(dispatchChannel)); } -function closeWithBadPayload(ws: WebSocket, hint: string) { - return ws.close(gatewayErrors.BAD_PAYLOAD.code, `${gatewayErrors.BAD_PAYLOAD.message}: ${hint}`); -} +function getInitialPresenceEntries(): GatewayPresenceEntry[] { + // The initial presence entries are sent right when the user connects. + // In the future, each user will have their own list of channels that they can join and leave. + // In that case, we will send the presence entries to a certain user only for the channels they're in. + + const entries: GatewayPresenceEntry[] = []; + + sessionsByUserId.forEach((clients: GatewayClient[], userId: number) => { + if (clients.length < 1) + return; + + const firstClient = clients[0]; + if (firstClient.ready && firstClient.user) { + const entry = firstClient.getPresenceEntry(GatewayPresenceStatus.Online); + if (entry) { + entries.push(entry); + } + } + }); + + return entries; +} function parseJsonOrNull(payload: string): any { try { return JSON.parse(payload); @@ -155,8 +132,6 @@ function parseJsonOrNull(payload: string): any { } } -// The function below ensures `payload` is of the GatewayPayload -// interface payload. If it does not match, null is returned. function ensureFormattedGatewayPayload(payload: any): GatewayPayload | null { if (!payload) { return null; @@ -183,288 +158,305 @@ function ensureFormattedGatewayPayload(payload: any): GatewayPayload | null { return asPayload; } -function sendPayload(ws: WebSocket, payload: GatewayPayload) { - ws.send(JSON.stringify(payload)); -} -function getPresenceEntryForConnection(ws: WebSocket, status: GatewayPresenceStatus): GatewayPresenceEntry | null { - if (!ws.state || !ws.state.user) { - return null; +class GatewayClient { + ws: WebSocket; + user?: User; + ready: boolean; + alive: boolean; + lastAliveCheck: number; + clientDispatchChannels: Set; + messagesSinceLastCheck: number; + bridgesTo?: string; + privacy?: string; + terms?: string; + + constructor(ws: WebSocket) { + this.ws = ws; + this.user = undefined; + this.ready = false; + this.alive = false; + this.lastAliveCheck = performance.now(); + this.clientDispatchChannels = new Set(); + this.messagesSinceLastCheck = 0; + this.bridgesTo = undefined; + this.privacy = undefined; + this.terms = undefined; + + gatewayClients.add(this); + this.ws.on("close", this.handleClose.bind(this)); + this.ws.on("message", this.handleMessage.bind(this)); } - const entry: GatewayPresenceEntry = { - user: { - id: ws.state.user.id, - username: ws.state.user.username, - avatar: ws.state.user.avatar - }, - status - }; - if (typeof ws.state.bridgesTo === "string") { - entry.bridgesTo = ws.state.bridgesTo; - } - if (typeof ws.state.privacy === "string") { - entry.privacy = ws.state.privacy; - } - if (typeof ws.state.terms === "string") { - entry.terms = ws.state.terms; + greet() { + this.send({ + t: GatewayPayloadType.Hello, + d: { + pingInterval: GATEWAY_PING_INTERVAL + } + }); } - return entry; -} + subscribe(channel: string) { + this.clientDispatchChannels.add(channel); + if (!dispatchChannels.get(channel)) { + dispatchChannels.set(channel, new Set()); + } + + dispatchChannels.get(channel)?.add(this); + } -// The initial presence entries are sent right when the user connects. -// In the future, each user will have their own list of channels that they can join and leave. -// In that case, we will send the presence entries to a certain user only for the channels they're in. -function getInitialPresenceEntries(): GatewayPresenceEntry[] { - const entries: GatewayPresenceEntry[] = []; + unsubscribeAll() { + this.clientDispatchChannels.forEach((channel) => { + const set = dispatchChannels.get(channel); + if (!set) return; + + set.delete(this); + if (set && set.size < 1) { + dispatchChannels.delete(channel); + } + }); + this.clientDispatchChannels.clear(); + } - sessionsByUserId.forEach((wsList: WebSocket[], userId: number) => { - if (wsList.length < 1) - return; + send(payload: object) { + this.ws.send(JSON.stringify(payload)); + } - const firstWs = wsList[0]; - if (firstWs.state.ready && firstWs.state.user) { - const entry = getPresenceEntryForConnection(firstWs, GatewayPresenceStatus.Online); - if (entry) { - entries.push(entry); + closeWithError({ code, message }: { code: number, message: string }) { + this.ws.close(code, message); + } + + closeWithBadPayload(hint: string) { + this.ws.close(gatewayErrors.BAD_PAYLOAD.code, `${gatewayErrors.BAD_PAYLOAD.message}: ${hint}`); + } + + getPresenceEntry(status: GatewayPresenceStatus): GatewayPresenceEntry | null { + if (!this.user || !this.ready) { + return null; + } + + return { + user: { + id: this.user.id, + username: this.user.username, + avatar: this.user.avatar + }, + status, + bridgesTo: this.bridgesTo, + privacy: this.privacy, + terms: this.terms, + }; + } + + handleClose() { + gatewayClients.delete(this); + this.unsubscribeAll(); + this.ready = false; + + if (this.user) { + const sessions = sessionsByUserId.get(this.user.id); + if (sessions) { + const index = sessions.indexOf(this); + sessions.splice(index, 1); + if (!sessions.length) { + sessionsByUserId.delete(this.user.id); + + // user no longer has any sessions, update presence + dispatch("*", { + t: GatewayPayloadType.PresenceUpdate, + d: [ + this.getPresenceEntry(GatewayPresenceStatus.Offline) + ] + }); + } } } - }); + } - return entries; + batchTick() { + const now = performance.now(); + if ((now - this.lastAliveCheck) >= GATEWAY_PING_INTERVAL) { + if (!this.ready) { + return this.closeWithError(gatewayErrors.AUTHENTICATION_TIMEOUT); + } + if (!this.alive) { + return this.closeWithError(gatewayErrors.NO_PING); + } + this.messagesSinceLastCheck = 0; + } + } + + async handleMessage(rawData: Buffer, isBinary: boolean) { + this.messagesSinceLastCheck++; + + if (rawData.byteLength >= maxGatewayPayloadByteLength) { + return this.closeWithError(gatewayErrors.PAYLOAD_TOO_LARGE); + } + + if (this.messagesSinceLastCheck > MAX_CLIENT_MESSAGES_PER_BATCH) { + return this.closeWithError(gatewayErrors.FLOODING); + } + + + let stringData: string; + let binaryStream: Buffer | null = null; + if (isBinary) { + // Binary frames are used in order combine our text data (JSON) with binary data. + // This is especially useful for calling RPC methods that, for example, upload files. + // The format is: [json payload]\n[begin binary stream] + + let jsonSlice; + let jsonOffset = -1; + for (let i = 0; i < maxGatewayJsonStringByteLength; i++) { + if (rawData.readUInt8(i) === 0x0A) { + // hit newline + jsonSlice = rawData.subarray(0, i); + jsonOffset = i + 1; + break; + } + } + if (!jsonSlice) { + return this.closeWithBadPayload("Did not find newline to delimit JSON from binary stream. JSON payload may be too large, or newline may be missing."); + } + + binaryStream = rawData.subarray(jsonOffset, rawData.byteLength); + stringData = jsonSlice.toString(); + } else { + stringData = rawData.toString(); + } + + if (stringData.length > maxGatewayJsonStringLength) { + return this.closeWithError(gatewayErrors.PAYLOAD_TOO_LARGE); + } + + const payload = ensureFormattedGatewayPayload(parseJsonOrNull(stringData)); + if (!payload) { + return this.closeWithBadPayload("Invalid JSON or message does not match schema"); + } + + switch (payload.t) { + case GatewayPayloadType.Authenticate: { + if (this.ready) { + return this.closeWithError(gatewayErrors.ALREADY_AUTHENTICATED); + } + + const authData = payload.d; + if (typeof authData !== "object") { + return this.closeWithBadPayload("d: expected object"); + } + + if (typeof authData.token !== "string") { + return this.closeWithBadPayload("d: invalid field 'token'"); + } + + const user = await decodeTokenOrNull(authData.token); + if (!user) { + return this.closeWithError(gatewayErrors.BAD_AUTH); + } + + let sessions = sessionsByUserId.get(user.id); + if (sessions) { + if ((sessions.length + 1) > MAX_GATEWAY_SESSIONS_PER_USER) { + return this.closeWithError(gatewayErrors.TOO_MANY_SESSIONS); + } + } + + // TODO: each user should have their own list of channels that they join + const [channels, communities] = await Promise.all([ + query("SELECT id, name, owner_id, community_id FROM channels ORDER BY id ASC"), + query("SELECT id, name, owner_id, avatar, created_at FROM communities ORDER BY id ASC"), + ]); + if (!channels || !communities) { + return this.closeWithError(gatewayErrors.GOT_NO_DATABASE_DATA); + } + + if (!sessions) { + sessions = []; + sessionsByUserId.set(user.id, sessions); + } + sessions.push(this); + + this.subscribe("*"); + for (let i = 0; i < channels.rows.length; i++) { + this.subscribe(`channel:${channels.rows[i].id}`); + } + for (let i = 0; i < communities.rows.length; i++) { + this.subscribe(`community:${communities.rows[i].id}`); + } + + this.user = user; + + // first session, notify others that we are online + if (sessions.length === 1) { + dispatch("*", { + t: GatewayPayloadType.PresenceUpdate, + d: [this.getPresenceEntry(GatewayPresenceStatus.Online)] + }); + } + + this.ready = true; + + this.send({ + t: GatewayPayloadType.Ready, + d: { + user: getPublicUserObject(this.user), + channels: channels.rows, + communities: communities.rows, + presence: getInitialPresenceEntries() + } + }); + break; + } + case GatewayPayloadType.Ping: { + if (payload.d !== 0) { + return this.closeWithBadPayload("d: expected numeric '0'"); + } + + if (!this.ready) { + return this.closeWithError(gatewayErrors.NOT_AUTHENTICATED); + } + + this.alive = true; + break; + } + case GatewayPayloadType.RPCSignal: /* through */ + case GatewayPayloadType.RPCRequest: { + if (!this.ready || !this.user) { + return this.closeWithError(gatewayErrors.NOT_AUTHENTICATED); + } + + // RPCSignal is like RPCRequest however it does not send RPC method output unless there is an error + processMethodBatch(this.user, payload.d, (payload.t === GatewayPayloadType.RPCSignal ? true : false), binaryStream).then((results) => { + this.send({ + t: GatewayPayloadType.RPCResponse, + d: results, + s: payload.s + }); + }); + break; + } + default: { + return this.closeWithBadPayload("t: unknown type"); + } + } + } } export default function(server: Server) { const wss = new WebSocketServer({ server }); const batchInterval = setInterval(() => { - wss.clients.forEach((e) => { - const now = performance.now(); - if (e.state && (now - e.state.lastAliveCheck) >= GATEWAY_PING_INTERVAL) { - if (!e.state.ready) { - return closeWithError(e, gatewayErrors.AUTHENTICATION_TIMEOUT); - } - if (!e.state.alive) { - return closeWithError(e, gatewayErrors.NO_PING); - } - e.state.messagesSinceLastCheck = 0; - } - }); + gatewayClients.forEach(client => client.batchTick()); }, GATEWAY_BATCH_INTERVAL); wss.on("close", () => { - console.warn("gateway: websocket server closed"); - console.warn("gateway: clearing batch interval due to websocket server close"); + console.error("gateway: websocket server closed"); clearInterval(batchInterval); }); wss.on("connection", (ws) => { - ws.state = { - user: undefined, - alive: false, - ready: false, - lastAliveCheck: performance.now(), - dispatchChannels: new Set(), - messagesSinceLastCheck: 0, - bridgesTo: undefined - }; - - sendPayload(ws, { - t: GatewayPayloadType.Hello, - d: { - pingInterval: GATEWAY_PING_INTERVAL - } - }); - - ws.on("close", () => { - clientUnsubscribeAll(ws); - ws.state.ready = false; - if (ws.state.user && ws.state.user.id) { - const sessions = sessionsByUserId.get(ws.state.user.id); - if (sessions) { - const index = sessions.indexOf(ws); - sessions.splice(index, 1); - if (sessions.length < 1) { - sessionsByUserId.delete(ws.state.user.id); - - // user no longer has any sessions, update presence - dispatch("*", { - t: GatewayPayloadType.PresenceUpdate, - d: [getPresenceEntryForConnection(ws, GatewayPresenceStatus.Offline)] - }); - } - } - } - }); - - ws.on("message", async (rawData: Buffer, isBinary) => { - ws.state.messagesSinceLastCheck++; - if (ws.state.messagesSinceLastCheck > MAX_CLIENT_MESSAGES_PER_BATCH) { - return closeWithError(ws, gatewayErrors.FLOODING); - } - - if (rawData.byteLength >= maxGatewayPayloadByteLength) { - return closeWithError(ws, gatewayErrors.PAYLOAD_TOO_LARGE); - } - - let stringData: string; - let binaryStream: Buffer | null = null; - if (isBinary) { - // Binary frames are used in order combine our text data (JSON) with binary data. - // This is especially useful for calling RPC methods that, for example, upload files. - // The format is: [json payload]\n[begin binary stream] - - let jsonSlice; - let jsonOffset = -1; - for (let i = 0; i < maxGatewayJsonStringByteLength; i++) { - if (rawData.readUInt8(i) === 0x0A) { - // hit newline - jsonSlice = rawData.subarray(0, i); - jsonOffset = i + 1; - break; - } - } - if (!jsonSlice) { - return closeWithBadPayload(ws, "Did not find newline to delimit JSON from binary stream. JSON payload may be too large, or newline may be missing."); - } - - binaryStream = rawData.subarray(jsonOffset, rawData.byteLength); - stringData = jsonSlice.toString(); - } else { - stringData = rawData.toString(); - } - - if (stringData.length > maxGatewayJsonStringLength) { - return closeWithError(ws, gatewayErrors.PAYLOAD_TOO_LARGE); - } - - const payload = ensureFormattedGatewayPayload(parseJsonOrNull(stringData)); - if (!payload) { - return closeWithBadPayload(ws, "Invalid JSON or message does not match schema"); - } - - switch (payload.t) { - case GatewayPayloadType.Authenticate: { - if (ws.state.ready) { - return closeWithError(ws, gatewayErrors.ALREADY_AUTHENTICATED); - } - - const authData = payload.d; - if (typeof authData !== "object") { - return closeWithBadPayload(ws, "d: expected object"); - } - - if (typeof authData.token !== "string") { - return closeWithBadPayload(ws, "d: invalid field 'token'"); - } - - if (typeof authData.bridgesTo !== "undefined" && typeof authData.bridgesTo !== "string" && authData.bridgesTo.length > 40) { - return closeWithBadPayload(ws, "d: invalid field 'bridgesTo'"); - } - - if (typeof authData.privacy !== "undefined" && typeof authData.privacy !== "string" && authData.privacy.length > 200) { - return closeWithBadPayload(ws, "d: invalid field 'privacy'"); - } - - if (typeof authData.terms !== "undefined" && typeof authData.terms !== "string" && authData.terms.length > 200) { - return closeWithBadPayload(ws, "d: invalid field 'terms'"); - } - - const user = await decodeTokenOrNull(authData.token); - if (!user) { - return closeWithError(ws, gatewayErrors.BAD_AUTH); - } - - let sessions = sessionsByUserId.get(user.id); - if (sessions) { - if ((sessions.length + 1) > MAX_GATEWAY_SESSIONS_PER_USER) { - return closeWithError(ws, gatewayErrors.TOO_MANY_SESSIONS); - } - } - - // TODO: each user should have their own list of channels that they join - const [channels, communities] = await Promise.all([ - query("SELECT id, name, owner_id, community_id FROM channels ORDER BY id ASC"), - query("SELECT id, name, owner_id, avatar, created_at FROM communities ORDER BY id ASC"), - ]); - if (!channels || !communities) { - return closeWithError(ws, gatewayErrors.GOT_NO_DATABASE_DATA); - } - - if (!sessions) { - sessions = []; - sessionsByUserId.set(user.id, sessions); - } - sessions.push(ws); - - clientSubscribe(ws, "*"); - for (let i = 0; i < channels.rows.length; i++) { - clientSubscribe(ws, `channel:${channels.rows[i].id}`); - } - for (let i = 0; i < communities.rows.length; i++) { - clientSubscribe(ws, `community:${communities.rows[i].id}`); - } - - ws.state.user = user; - ws.state.bridgesTo = authData.bridgesTo; - ws.state.privacy = authData.privacy; - ws.state.terms = authData.terms; - - // first session, notify others that we are online - if (sessions.length === 1) { - dispatch("*", { - t: GatewayPayloadType.PresenceUpdate, - d: [getPresenceEntryForConnection(ws, GatewayPresenceStatus.Online)] - }); - } - - ws.state.ready = true; - - sendPayload(ws, { - t: GatewayPayloadType.Ready, - d: { - user: getPublicUserObject(ws.state.user), - channels: channels.rows, - communities: communities.rows, - presence: getInitialPresenceEntries() - } - }); - break; - } - case GatewayPayloadType.Ping: { - if (payload.d !== 0) { - return closeWithBadPayload(ws, "d: expected numeric '0'"); - } - - if (!ws.state.ready) { - return closeWithError(ws, gatewayErrors.NOT_AUTHENTICATED); - } - - // TODO: also check session here - ws.state.alive = true; - break; - } - case GatewayPayloadType.RPCSignal: /* through */ - case GatewayPayloadType.RPCRequest: { - if (!ws.state.ready || !ws.state.user) { - return closeWithError(ws, gatewayErrors.NOT_AUTHENTICATED); - } - - // RPCSignal is like RPCRequest however it does not send RPC method output unless there is an error - processMethodBatch(ws.state.user, payload.d, (payload.t === GatewayPayloadType.RPCSignal ? true : false), binaryStream).then((results) => { - sendPayload(ws, { - t: GatewayPayloadType.RPCResponse, - d: results, - s: payload.s - }); - }); - break; - } - default: { - return closeWithBadPayload(ws, "t: unknown type"); - } - } - }); + const client = new GatewayClient(ws); + client.greet(); }); }; diff --git a/src/types/gatewayclientstate.d.ts b/src/types/gatewayclientstate.d.ts deleted file mode 100644 index 4fc8d76..0000000 --- a/src/types/gatewayclientstate.d.ts +++ /dev/null @@ -1,11 +0,0 @@ -interface GatewayClientState { - user?: User; - ready: boolean, - alive: boolean, - lastAliveCheck: number, - dispatchChannels: Set, - messagesSinceLastCheck: number, - bridgesTo?: string - privacy?: string, - terms?: string, -} diff --git a/src/types/ws.d.ts b/src/types/ws.d.ts deleted file mode 100644 index cb08647..0000000 --- a/src/types/ws.d.ts +++ /dev/null @@ -1,7 +0,0 @@ -import ws from 'ws'; - -declare module 'ws' { - export interface WebSocket extends ws { - state: GatewayClientState; - } -}