import { Server } from "node:http";
import { performance } from "node:perf_hooks";
import { WebSocketServer, WebSocket } from "ws";
import { decodeTokenOrNull, getPublicUserObject } from "../auth";
import { query } from "../database";
import { gatewayErrors } from "../errors";
import { GatewayPayload } from "../types/gatewaypayload";
import { GatewayPayloadType } from "./gatewaypayloadtype";

/**
 * Parses an interval (in milliseconds) taken from an environment variable.
 *
 * @param raw - The raw environment value, or `undefined` when it is not set.
 * @param fallback - Value used when `raw` is unset or is not a positive finite number.
 * @returns The parsed interval, or `fallback`.
 */
function readIntervalMs(raw: string | undefined, fallback: number): number {
    if (raw === undefined) {
        return fallback;
    }
    const parsed = Number(raw);
    return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
}

// The previous form `25000 || process.env.X` always evaluated to the literal,
// so the environment overrides were never honoured.
const GATEWAY_BATCH_INTERVAL = readIntervalMs(process.env.GATEWAY_BATCH_INTERVAL, 25000);
const GATEWAY_PING_INTERVAL = readIntervalMs(process.env.GATEWAY_PING_INTERVAL, 20000);

// Mapping between a broadcast channel id and the websocket clients subscribed to it.
const broadcastChannels = new Map<string, Set<WebSocket>>();

/**
 * Subscribes `ws` to `broadcastChannel`, recording the subscription both on the
 * socket's own state and in the module-level channel map.
 */
function clientSubscribe(ws: WebSocket, broadcastChannel: string) {
    ws.state.broadcastChannels.add(broadcastChannel);

    let members = broadcastChannels.get(broadcastChannel);
    if (!members) {
        members = new Set<WebSocket>();
        broadcastChannels.set(broadcastChannel, members);
    }
    members.add(ws);
}

/**
 * Removes `ws` from the member set of `broadcastChannel` and drops the channel
 * entry from the map once nobody is subscribed to it anymore.
 */
function removeChannelMember(ws: WebSocket, broadcastChannel: string) {
    const members = broadcastChannels.get(broadcastChannel);
    if (!members) return;

    members.delete(ws);
    if (members.size < 1) {
        broadcastChannels.delete(broadcastChannel);
    }
}

/**
 * Unsubscribes `ws` from a single broadcast channel.
 * Does nothing when the socket has no gateway state attached.
 */
function clientUnsubscribe(ws: WebSocket, broadcastChannel: string) {
    if (!ws.state) return;

    ws.state.broadcastChannels.delete(broadcastChannel);
    removeChannelMember(ws, broadcastChannel);
}

/**
 * Unsubscribes `ws` from every broadcast channel it joined and resets its own
 * subscription set. Does nothing when the socket has no gateway state attached.
 */
function clientUnsubscribeAll(ws: WebSocket) {
    if (!ws.state) return;

    ws.state.broadcastChannels.forEach(channel => removeChannelMember(ws, channel));
    ws.state.broadcastChannels = new Set();
}

/**
 * Sends `message` to every client subscribed to `channel`.
 *
 * The payload is serialized once and sockets that are not in the OPEN state
 * are skipped. Does nothing when the channel has no subscribers.
 */
export function broadcast(channel: string, message: GatewayPayload) {
    const members = broadcastChannels.get(channel);
    if (!members) return;

    const serialized = JSON.stringify(message);
    members.forEach(member => {
        if (member.readyState === WebSocket.OPEN) {
            member.send(serialized);
        }
    });
}

/** Closes `ws` using the code and message of a gateway error definition. */
function closeWithError(ws: WebSocket, { code, message }: { code: number, message: string }) {
    return ws.close(code, message);
}

/** Closes `ws` with the BAD_PAYLOAD gateway error, appending `hint` to the close reason. */
function closeWithBadPayload(ws: WebSocket, hint: string) {
    return ws.close(gatewayErrors.BAD_PAYLOAD.code, `${gatewayErrors.BAD_PAYLOAD.message}: ${hint}`);
}

/** Parses `payload` as JSON, returning `null` instead of throwing on malformed input. */
function parseJsonOrNull(payload: string): unknown {
    try {
        return JSON.parse(payload);
    } catch (e) {
        return null;
    }
}

/**
 * Ensures `payload` has the shape of a GatewayPayload: an object whose only
 * keys are `t` (a number) and `d` (any value), both present.
 *
 * @returns The payload typed as GatewayPayload, or `null` if it does not match.
 */
function ensureFormattedGatewayPayload(payload: unknown): GatewayPayload | null {
    if (typeof payload !== "object" || payload === null) {
        return null;
    }

    let foundT = false;
    let foundD = false;
    for (const [k, v] of Object.entries(payload)) {
        if (k === "t" && typeof v === "number") {
            foundT = true;
        } else if (k === "d") {
            foundD = true;
        } else {
            // Unknown key, or `t` with a non-numeric value.
            return null;
        }
    }

    if (!foundT || !foundD) {
        return null;
    }

    return payload as GatewayPayload;
}

/** Serializes `payload` as JSON and sends it to `ws`. */
function sendPayload(ws: WebSocket, payload: GatewayPayload) {
    ws.send(JSON.stringify(payload));
}

/**
 * Attaches the gateway websocket server to `server`.
 *
 * Every new connection receives a Hello payload carrying the ping interval.
 * A periodic sweep closes clients that have not authenticated, or have not
 * pinged, within the ping interval.
 */
export default function(server: Server) {
    const wss = new WebSocketServer({ server });

    const batchInterval = setInterval(() => {
        const now = performance.now();
        wss.clients.forEach((client) => {
            const state = client.state;
            if (!state || (now - state.lastAliveCheck) < GATEWAY_PING_INTERVAL) {
                return;
            }
            if (!state.ready) {
                closeWithError(client, gatewayErrors.AUTHENTICATION_TIMEOUT);
                return;
            }
            if (!state.alive) {
                closeWithError(client, gatewayErrors.NO_PING);
                return;
            }
            // Start a new liveness window. Without this reset a single ping
            // would keep the client marked alive for the rest of the connection.
            state.alive = false;
            state.lastAliveCheck = now;
        });
    }, GATEWAY_BATCH_INTERVAL);

    wss.on("close", () => {
        console.warn("gateway: websocket server closed");
        console.warn("gateway: clearing batch interval due to websocket server close");
        clearInterval(batchInterval);
    });

    wss.on("connection", (ws) => {
        ws.state = {
            user: undefined,
            alive: false,
            ready: false,
            lastAliveCheck: performance.now(),
            broadcastChannels: new Set()
        };

        sendPayload(ws, { t: GatewayPayloadType.Hello, d: { pingInterval: GATEWAY_PING_INTERVAL } });

        ws.on("close", () => {
            clientUnsubscribeAll(ws);
        });

        ws.on("message", async (rawData, isBinary) => {
            // The handler is async: anything thrown here would otherwise become
            // an unhandled promise rejection, so failures are caught and the
            // connection is closed instead.
            try {
                if (isBinary) {
                    closeWithBadPayload(ws, "Binary messages are not supported");
                    return;
                }

                const payload = ensureFormattedGatewayPayload(parseJsonOrNull(rawData.toString()));
                if (!payload) {
                    closeWithBadPayload(ws, "Invalid JSON or message does not match schema");
                    return;
                }

                switch (payload.t) {
                    case GatewayPayloadType.Authenticate: {
                        const token = payload.d;
                        if (typeof token !== "string") {
                            closeWithBadPayload(ws, "d: expected string");
                            return;
                        }

                        const user = await decodeTokenOrNull(token);
                        if (!user) {
                            closeWithError(ws, gatewayErrors.BAD_AUTH);
                            return;
                        }

                        // TODO: each user should have their own list of channels that they join
                        const channels = await query("SELECT id, name, owner_id FROM channels");

                        // The socket may have closed while awaiting. Its close
                        // handler has already run, so subscribing now would leave
                        // a dead socket in the channel map forever.
                        if (ws.readyState !== WebSocket.OPEN) {
                            return;
                        }

                        channels.rows.forEach(c => {
                            clientSubscribe(ws, `channel:${c.id}`);
                        });

                        ws.state.user = user;
                        ws.state.ready = true;

                        sendPayload(ws, {
                            t: GatewayPayloadType.Ready,
                            d: {
                                user: getPublicUserObject(ws.state.user),
                                channels: channels.rows
                            }
                        });
                        break;
                    }
                    case GatewayPayloadType.Ping: {
                        // TODO: also check session here and ensure packet is sent at the right time
                        ws.state.alive = true;
                        break;
                    }
                    default: {
                        closeWithBadPayload(ws, "t: unknown type");
                        return;
                    }
                }
            } catch (err) {
                console.error("gateway: error while handling message", err);
                if (ws.readyState === WebSocket.OPEN) {
                    // 1011: the server hit an unexpected condition (RFC 6455).
                    ws.close(1011, "Internal server error");
                }
            }
        });
    });
};