import { Server } from "node:http";
import { performance } from "node:perf_hooks";
import { WebSocketServer, WebSocket } from "ws";
import { decodeTokenOrNull, getPublicUserObject } from "../auth";
import { query } from "../database";
import { gatewayErrors } from "../errors";
import { GatewayPayload } from "../types/gatewaypayload";
import { GatewayPayloadType } from "./gatewaypayloadtype";

/** Period, in milliseconds, of the server-side sweep over all connected clients. */
const GATEWAY_BATCH_INTERVAL = 50000;

/**
 * Advertised to clients in the Hello payload as `pingInterval`, and used by the
 * sweep as the minimum time (ms) since a client's last check before it is checked again.
 */
const GATEWAY_PING_INTERVAL = 40000;

/** Messages a client may send between two sweep resets before it is closed for flooding. */
const MAX_CLIENT_MESSAGES_PER_BATCH = 6;
// TODO: how well does this work for weak connections?

/** Standard WebSocket close code for "internal error" (RFC 6455, section 7.4.1). */
const WS_CLOSE_INTERNAL_ERROR = 1011;

// mapping between a dispatch id and a websocket client
const dispatchChannels = new Map<string, Set<WebSocket>>();

/**
 * Removes `ws` from the member set of `dispatchChannel` and drops the set from
 * the registry once it becomes empty. Does not touch `ws.state`.
 */
function removeChannelMember(ws: WebSocket, dispatchChannel: string): void {
    const members = dispatchChannels.get(dispatchChannel);
    if (!members) return;

    members.delete(ws);
    if (members.size < 1) {
        dispatchChannels.delete(dispatchChannel);
    }
}

/**
 * Subscribes a client to a dispatch channel, recording the subscription both on
 * the client's own state and in the global channel registry.
 */
function clientSubscribe(ws: WebSocket, dispatchChannel: string): void {
    ws.state.dispatchChannels.add(dispatchChannel);

    let members = dispatchChannels.get(dispatchChannel);
    if (!members) {
        members = new Set<WebSocket>();
        dispatchChannels.set(dispatchChannel, members);
    }
    members.add(ws);
}

/**
 * Unsubscribes a client from a single dispatch channel.
 * Does nothing when the client has no state attached.
 */
function clientUnsubscribe(ws: WebSocket, dispatchChannel: string): void {
    if (!ws.state) return;

    ws.state.dispatchChannels.delete(dispatchChannel);
    removeChannelMember(ws, dispatchChannel);
}

/**
 * Subscribes every client currently subscribed to `target` to `dispatchChannel`.
 *
 * @param target - Existing dispatch channel whose members should be subscribed.
 * @param dispatchChannel - Channel the members are added to.
 */
export function dispatchChannelSubscribe(target: string, dispatchChannel: string): void {
    const members = dispatchChannels.get(target);
    if (!members) return;

    members.forEach((client) => {
        clientSubscribe(client, dispatchChannel);
    });
}

/**
 * Removes a client from every dispatch channel it is subscribed to and resets
 * its own subscription list. Does nothing when the client has no state attached.
 */
function clientUnsubscribeAll(ws: WebSocket): void {
    if (!ws.state) return;

    ws.state.dispatchChannels.forEach((channel) => {
        removeChannelMember(ws, channel);
    });
    ws.state.dispatchChannels = new Set<string>();
}

/**
 * Sends `message`, JSON-encoded, to every client subscribed to `channel`.
 * Does nothing when the channel has no subscribers.
 *
 * @param channel - Dispatch channel id.
 * @param message - Payload to broadcast.
 */
export function dispatch(channel: string, message: GatewayPayload): void {
    const members = dispatchChannels.get(channel);
    if (!members) return;

    // The payload is identical for every member, so serialize it only once.
    const serialized = JSON.stringify(message);
    members.forEach((client) => {
        client.send(serialized);
    });
}

/** Closes the socket with the code and message of a gateway error descriptor. */
function closeWithError(ws: WebSocket, { code, message }: { code: number, message: string }) {
    return ws.close(code, message);
}

/** Closes the socket with the BAD_PAYLOAD gateway error, appending `hint` to its message. */
function closeWithBadPayload(ws: WebSocket, hint: string) {
    return ws.close(gatewayErrors.BAD_PAYLOAD.code, `${gatewayErrors.BAD_PAYLOAD.message}: ${hint}`);
}

/** Parses `payload` as JSON, returning null instead of throwing on malformed input. */
function parseJsonOrNull(payload: string): unknown {
    try {
        return JSON.parse(payload);
    } catch {
        return null;
    }
}

// The function below ensures `payload` is of the GatewayPayload
// interface payload. If it does not match, null is returned.
// A match means: an object with exactly the keys `t` (a number) and `d`
// (any value) and no other keys.
function ensureFormattedGatewayPayload(payload: unknown): GatewayPayload | null {
    if (typeof payload !== "object" || payload === null) {
        return null;
    }

    let foundT = false;
    let foundD = false;
    for (const [key, value] of Object.entries(payload)) {
        if (key === "t" && typeof value === "number") {
            foundT = true;
        } else if (key === "d") {
            foundD = true;
        } else {
            // Unknown key, or `t` with a non-numeric value.
            return null;
        }
    }

    if (!foundT || !foundD) {
        return null;
    }

    return payload as GatewayPayload;
}

/** Sends a single JSON-encoded payload to one client. */
function sendPayload(ws: WebSocket, payload: GatewayPayload): void {
    ws.send(JSON.stringify(payload));
}

/**
 * Attaches the gateway WebSocket server to an HTTP server.
 *
 * Each connection receives a Hello payload, must authenticate, and is swept
 * periodically: clients that have not authenticated or have not pinged since
 * their previous check are closed, and the flood counter is reset.
 *
 * @param server - HTTP server the WebSocket server is bound to.
 */
export default function(server: Server): void {
    const wss = new WebSocketServer({ server });

    const batchInterval = setInterval(() => {
        const now = performance.now();

        for (const client of wss.clients) {
            const state = client.state;
            if (!state || (now - state.lastAliveCheck) < GATEWAY_PING_INTERVAL) {
                continue;
            }

            if (!state.ready) {
                closeWithError(client, gatewayErrors.AUTHENTICATION_TIMEOUT);
                continue;
            }
            if (!state.alive) {
                closeWithError(client, gatewayErrors.NO_PING);
                continue;
            }

            // Re-arm the liveness check: the client has to ping again before the
            // next sweep. Without this a single ping would keep it "alive" forever.
            state.alive = false;
            state.lastAliveCheck = now;
            state.messagesSinceLastCheck = 0;
        }
    }, GATEWAY_BATCH_INTERVAL);

    wss.on("close", () => {
        console.warn("gateway: websocket server closed");
        console.warn("gateway: clearing batch interval due to websocket server close");
        clearInterval(batchInterval);
    });

    wss.on("connection", (ws) => {
        ws.state = {
            user: undefined,
            alive: false,
            ready: false,
            lastAliveCheck: performance.now(),
            dispatchChannels: new Set<string>(),
            messagesSinceLastCheck: 0
        };

        sendPayload(ws, {
            t: GatewayPayloadType.Hello,
            d: {
                pingInterval: GATEWAY_PING_INTERVAL
            }
        });

        ws.on("close", () => {
            clientUnsubscribeAll(ws);
        });

        ws.on("message", async (rawData, isBinary) => {
            // The handler is async, so anything thrown here (e.g. a failed database
            // query) would otherwise surface as an unhandled promise rejection.
            try {
                if (isBinary) {
                    closeWithBadPayload(ws, "Binary messages are not supported");
                    return;
                }

                ws.state.messagesSinceLastCheck++;
                if (ws.state.messagesSinceLastCheck > MAX_CLIENT_MESSAGES_PER_BATCH) {
                    closeWithError(ws, gatewayErrors.FLOODING);
                    return;
                }

                const payload = ensureFormattedGatewayPayload(parseJsonOrNull(rawData.toString()));
                if (!payload) {
                    closeWithBadPayload(ws, "Invalid JSON or message does not match schema");
                    return;
                }

                switch (payload.t) {
                    case GatewayPayloadType.Authenticate: {
                        if (ws.state.ready) {
                            closeWithError(ws, gatewayErrors.ALREADY_AUTHENTICATED);
                            return;
                        }

                        const token = payload.d;
                        if (typeof token !== "string") {
                            closeWithBadPayload(ws, "d: expected string");
                            return;
                        }

                        const user = await decodeTokenOrNull(token);
                        if (!user) {
                            closeWithError(ws, gatewayErrors.BAD_AUTH);
                            return;
                        }

                        // each user should have their own list of channels that they join
                        const channels = await query("SELECT id, name, owner_id FROM channels");

                        // The socket may have closed while we were awaiting. Its close
                        // handler has already run, so subscribing now would leave a dead
                        // socket in the channel registry forever.
                        if (ws.readyState !== WebSocket.OPEN) {
                            return;
                        }
                        // A second Authenticate message may have completed in the meantime.
                        if (ws.state.ready) {
                            closeWithError(ws, gatewayErrors.ALREADY_AUTHENTICATED);
                            return;
                        }

                        clientSubscribe(ws, "*");
                        channels.rows.forEach(c => {
                            clientSubscribe(ws, `channel:${c.id}`);
                        });

                        ws.state.user = user;
                        ws.state.ready = true;

                        sendPayload(ws, {
                            t: GatewayPayloadType.Ready,
                            d: {
                                user: getPublicUserObject(user),
                                channels: channels.rows
                            }
                        });
                        break;
                    }
                    case GatewayPayloadType.Ping: {
                        // TODO: also check session here and ensure packet is sent at the right time
                        ws.state.alive = true;
                        break;
                    }
                    default: {
                        closeWithBadPayload(ws, "t: unknown type");
                        return;
                    }
                }
            } catch (e: unknown) {
                console.error("gateway: unexpected error while handling client message", e);
                ws.close(WS_CLOSE_INTERNAL_ERROR, "Internal server error");
            }
        });
    });
}