import { Server } from "node:http";
import { performance } from "node:perf_hooks";
import { WebSocketServer, WebSocket } from "ws";
import { decodeTokenOrNull, getPublicUserObject } from "../auth";
import { query } from "../database";
import { gatewayErrors } from "../errors";
import { GatewayPayload } from "../types/gatewaypayload";
import { GatewayPayloadType, GatewayPresenceStatus } from "./gatewaypayloadtype";
import { GatewayPresenceEntry } from "./gatewaypresence";

// Period (ms) of the server-side sweep over all connected clients.
const GATEWAY_BATCH_INTERVAL = 50000;
// Ping interval (ms) advertised to clients in the Hello payload; also the minimum
// age of a client's last liveness check before the sweep inspects it again.
const GATEWAY_PING_INTERVAL = 40000;
const MAX_CLIENT_MESSAGES_PER_BATCH = 6; // TODO: how well does this work for weak connections?
const MAX_GATEWAY_SESSIONS_PER_USER = 5;

// Mapping between a dispatch channel name and the websocket clients subscribed to it.
const dispatchChannels = new Map<string, Set<WebSocket>>();

// Mapping between a user id and the websocket sessions it has open.
const sessionsByUserId = new Map<number, WebSocket[]>();

/**
 * Subscribes `ws` to `dispatchChannel`, recording the subscription both on the
 * client's own state and in the global channel -> members map.
 */
function clientSubscribe(ws: WebSocket, dispatchChannel: string) {
    ws.state.dispatchChannels.add(dispatchChannel);
    let members = dispatchChannels.get(dispatchChannel);
    if (!members) {
        members = new Set<WebSocket>();
        dispatchChannels.set(dispatchChannel, members);
    }
    members.add(ws);
}

/**
 * Removes `ws` from the global member set of `dispatchChannel`, dropping the
 * channel entry entirely once nobody is subscribed, so empty sets don't accumulate.
 * Does not touch the client's own `state.dispatchChannels`.
 */
function removeChannelMember(ws: WebSocket, dispatchChannel: string) {
    const members = dispatchChannels.get(dispatchChannel);
    if (!members) return;
    members.delete(ws);
    if (members.size < 1) {
        dispatchChannels.delete(dispatchChannel);
    }
}

/**
 * Unsubscribes `ws` from a single dispatch channel. No-op for a client whose
 * state has not been initialized.
 */
function clientUnsubscribe(ws: WebSocket, dispatchChannel: string) {
    if (!ws.state) return;
    ws.state.dispatchChannels.delete(dispatchChannel);
    removeChannelMember(ws, dispatchChannel);
}

/**
 * Subscribes every client currently in channel `target` to `dispatchChannel`.
 * Does nothing if `target` has no subscribers.
 */
export function dispatchChannelSubscribe(target: string, dispatchChannel: string) {
    const members = dispatchChannels.get(target);
    if (!members) return;
    members.forEach(c => {
        clientSubscribe(c, dispatchChannel);
    });
}

/**
 * Removes `ws` from every dispatch channel it is subscribed to and clears its
 * own subscription set. No-op for a client whose state has not been initialized.
 */
function clientUnsubscribeAll(ws: WebSocket) {
    if (!ws.state) return;
    ws.state.dispatchChannels.forEach(channel => {
        removeChannelMember(ws, channel);
    });
    ws.state.dispatchChannels = new Set();
}

/**
 * Sends `message` (JSON-encoded) to every *ready* client subscribed to `channel`.
 * Clients that have not finished authenticating are skipped.
 */
export function dispatch(channel: string, message: GatewayPayload) {
    const members = dispatchChannels.get(channel);
    if (!members) return;
    // Serialize once instead of once per recipient.
    const serialized = JSON.stringify(message);
    members.forEach(member => {
        if (member.state.ready) {
            member.send(serialized);
        }
    });
}

/** Closes `ws` with the given gateway error's code and message as the close reason. */
function closeWithError(ws: WebSocket, { code, message }: { code: number, message: string }) {
    return ws.close(code, message);
}

/** Closes `ws` with the BAD_PAYLOAD error, appending `hint` to the close reason. */
function closeWithBadPayload(ws: WebSocket, hint: string) {
    return ws.close(gatewayErrors.BAD_PAYLOAD.code, `${gatewayErrors.BAD_PAYLOAD.message}: ${hint}`);
}

/** Parses `payload` as JSON, returning null instead of throwing on invalid input. */
function parseJsonOrNull(payload: string): unknown {
    try {
        return JSON.parse(payload);
    } catch {
        return null;
    }
}

/**
 * Ensures `payload` matches the GatewayPayload shape: an object whose only keys
 * are a numeric `t` and a `d` (of any value), both present. Returns the payload
 * typed as GatewayPayload, or null if it does not match.
 */
function ensureFormattedGatewayPayload(payload: unknown): GatewayPayload | null {
    // Primitives never match: strings only expose index keys, other primitives none.
    if (typeof payload !== "object" || payload === null) {
        return null;
    }

    let foundT = false;
    let foundD = false;
    for (const [key, value] of Object.entries(payload)) {
        if (key === "t" && typeof value === "number") {
            foundT = true;
        } else if (key === "d") {
            foundD = true;
        } else {
            // Unknown key, or a `t` that is not a number.
            return null;
        }
    }

    if (!foundT || !foundD) {
        return null;
    }
    return payload as GatewayPayload;
}

/** JSON-encodes `payload` and sends it to a single client. */
function sendPayload(ws: WebSocket, payload: GatewayPayload) {
    ws.send(JSON.stringify(payload));
}

/** Builds a presence entry exposing only the user's id and username alongside `status`. */
function getPresenceEntryForUser(user: User, status: GatewayPresenceStatus): GatewayPresenceEntry {
    return {
        user: {
            id: user.id,
            username: user.username
        },
        status
    };
}

// The initial presence entries are sent right when the user connects.
// In the future, each user will have their own list of channels that they can join and leave.
// In that case, we will send presence entries to a given user only for the channels they're in.
/**
 * Builds the presence list sent in the Ready payload: one Online entry per user
 * that currently has at least one ready, authenticated session.
 */
function getInitialPresenceEntries(): GatewayPresenceEntry[] {
    const entries: GatewayPresenceEntry[] = [];
    sessionsByUserId.forEach((sessions: WebSocket[]) => {
        // Any ready session is enough to show the user as online; don't rely on
        // the first session in the list being the ready one.
        const user = sessions.find(s => s.state.ready && s.state.user)?.state.user;
        if (user) {
            entries.push(getPresenceEntryForUser(user, GatewayPresenceStatus.Online));
        }
    });
    return entries;
}

/**
 * Attaches the gateway WebSocket server to `server`.
 *
 * Protocol as implemented here:
 * - The server sends Hello with `pingInterval`.
 * - The client must send Authenticate with a token string in `d`.
 * - The client then periodically sends Ping with `d: 0`.
 * A sweep every GATEWAY_BATCH_INTERVAL ms closes clients that never
 * authenticated or that did not ping since the previous check.
 *
 * NOTE(review): the message handler is async; if decodeTokenOrNull or query can
 * reject (rather than return null), that rejection is unhandled — confirm.
 */
export default function(server: Server) {
    const wss = new WebSocketServer({ server });

    // Sockets whose Authenticate request is still awaiting token/database lookups.
    // Used to reject a second Authenticate sent before the first one completes.
    const authenticating = new WeakSet<WebSocket>();

    const batchInterval = setInterval(() => {
        const now = performance.now();
        wss.clients.forEach((client) => {
            if (client.state && (now - client.state.lastAliveCheck) >= GATEWAY_PING_INTERVAL) {
                if (!client.state.ready) {
                    return closeWithError(client, gatewayErrors.AUTHENTICATION_TIMEOUT);
                }
                if (!client.state.alive) {
                    return closeWithError(client, gatewayErrors.NO_PING);
                }
                // Start a new liveness window: the client must ping again before
                // the next check, otherwise it gets closed with NO_PING.
                client.state.alive = false;
                client.state.lastAliveCheck = now;
                client.state.messagesSinceLastCheck = 0;
            }
        });
    }, GATEWAY_BATCH_INTERVAL);

    wss.on("close", () => {
        console.warn("gateway: websocket server closed");
        console.warn("gateway: clearing batch interval due to websocket server close");
        clearInterval(batchInterval);
    });

    wss.on("connection", (ws) => {
        ws.state = {
            user: undefined,
            alive: false,
            ready: false,
            lastAliveCheck: performance.now(),
            dispatchChannels: new Set(),
            messagesSinceLastCheck: 0
        };

        sendPayload(ws, {
            t: GatewayPayloadType.Hello,
            d: {
                pingInterval: GATEWAY_PING_INTERVAL
            }
        });

        ws.on("close", () => {
            clientUnsubscribeAll(ws);
            ws.state.ready = false;

            const user = ws.state.user;
            if (!user || !user.id) return;

            const sessions = sessionsByUserId.get(user.id);
            if (!sessions) return;

            const index = sessions.indexOf(ws);
            // indexOf returns -1 when ws is not registered; splice(-1, 1) would
            // then remove some *other* session, so bail out instead.
            if (index === -1) return;
            sessions.splice(index, 1);

            if (sessions.length < 1) {
                sessionsByUserId.delete(user.id);

                // user no longer has any sessions, update presence
                dispatch("*", {
                    t: GatewayPayloadType.PresenceUpdate,
                    d: [getPresenceEntryForUser(user, GatewayPresenceStatus.Offline)]
                });
            }
        });

        ws.on("message", async (rawData, isBinary) => {
            if (isBinary) {
                return closeWithBadPayload(ws, "Binary messages are not supported");
            }

            ws.state.messagesSinceLastCheck++;
            if (ws.state.messagesSinceLastCheck > MAX_CLIENT_MESSAGES_PER_BATCH) {
                return closeWithError(ws, gatewayErrors.FLOODING);
            }

            const stringData = rawData.toString();
            if (stringData.length > 2048) {
                return closeWithError(ws, gatewayErrors.PAYLOAD_TOO_LARGE);
            }

            const payload = ensureFormattedGatewayPayload(parseJsonOrNull(stringData));
            if (!payload) {
                return closeWithBadPayload(ws, "Invalid JSON or message does not match schema");
            }

            switch (payload.t) {
                case GatewayPayloadType.Authenticate: {
                    if (ws.state.ready || authenticating.has(ws)) {
                        return closeWithError(ws, gatewayErrors.ALREADY_AUTHENTICATED);
                    }

                    const token = payload.d;
                    if (typeof token !== "string") {
                        return closeWithBadPayload(ws, "d: expected string");
                    }

                    authenticating.add(ws);
                    try {
                        const user = await decodeTokenOrNull(token);
                        if (!user) {
                            return closeWithError(ws, gatewayErrors.BAD_AUTH);
                        }

                        // TODO: each user should have their own list of channels that they join
                        const channels = await query("SELECT id, name, owner_id FROM channels ORDER BY id ASC");
                        if (!channels) {
                            return closeWithError(ws, gatewayErrors.GOT_NO_DATABASE_DATA);
                        }

                        // The socket may have been closed (by the client, the sweep, or
                        // flood protection) while we were awaiting. Registering it now
                        // would leak a session its "close" handler will never remove.
                        if (ws.readyState !== WebSocket.OPEN) {
                            return;
                        }

                        let sessions = sessionsByUserId.get(user.id);
                        if (sessions) {
                            if ((sessions.length + 1) > MAX_GATEWAY_SESSIONS_PER_USER) {
                                return closeWithError(ws, gatewayErrors.TOO_MANY_SESSIONS);
                            }
                        } else {
                            sessions = [];
                            sessionsByUserId.set(user.id, sessions);
                        }

                        // No awaits from here on. The session is registered and
                        // `state.user` is set together, so the "close" handler can
                        // always find and remove it.
                        sessions.push(ws);
                        ws.state.user = user;

                        clientSubscribe(ws, "*");
                        channels.rows.forEach(c => {
                            clientSubscribe(ws, `channel:${c.id}`);
                        });

                        // First session: notify others that we are online. This runs
                        // before `ready` is set, so this client does not receive its
                        // own presence update.
                        if (sessions.length === 1) {
                            dispatch("*", {
                                t: GatewayPayloadType.PresenceUpdate,
                                d: [getPresenceEntryForUser(ws.state.user, GatewayPresenceStatus.Online)]
                            });
                        }

                        ws.state.ready = true;

                        sendPayload(ws, {
                            t: GatewayPayloadType.Ready,
                            d: {
                                user: getPublicUserObject(ws.state.user),
                                channels: channels.rows,
                                presence: getInitialPresenceEntries()
                            }
                        });
                    } finally {
                        authenticating.delete(ws);
                    }
                    break;
                }
                case GatewayPayloadType.Ping: {
                    if (payload.d !== 0) {
                        return closeWithBadPayload(ws, "d: expected numeric '0'");
                    }
                    if (!ws.state.ready) {
                        return closeWithError(ws, gatewayErrors.NOT_AUTHENTICATED);
                    }
                    // TODO: also check session here
                    ws.state.alive = true;
                    break;
                }
                default: {
                    return closeWithBadPayload(ws, "t: unknown type");
                }
            }
        });
    });
}