diff --git a/src/auth.ts b/src/auth.ts index 3ba9414..a8cace2 100644 --- a/src/auth.ts +++ b/src/auth.ts @@ -15,6 +15,13 @@ if (jwtSecret === "[generic token]") { process.exit(1); } +export function getPublicUserObject(user: User) { + const newUser = { ...user }; + newUser.password = undefined; + delete newUser.password; + + return newUser; +} export function signToken(userId: number) { return new Promise((resolve, reject) => { @@ -99,10 +106,7 @@ export function authenticateRoute() { } req.user = user; - req.publicUser = {...user}; - if (req.publicUser.password) { - delete req.publicUser.password; - } + req.publicUser = getPublicUserObject(user); next(); }; diff --git a/src/gateway/gatewaypayloadtype.ts b/src/gateway/gatewaypayloadtype.ts index 4b1bee9..6c848c3 100644 --- a/src/gateway/gatewaypayloadtype.ts +++ b/src/gateway/gatewaypayloadtype.ts @@ -1,5 +1,6 @@ export enum GatewayPayloadType { Hello = 0, Authenticate, - Ready + Ready, + Ping } diff --git a/src/gateway/index.ts b/src/gateway/index.ts index 2882d4d..76a2b9f 100644 --- a/src/gateway/index.ts +++ b/src/gateway/index.ts @@ -1,7 +1,8 @@ import { Server } from "node:http"; import { performance } from "node:perf_hooks"; -import WebSocket, { WebSocketServer } from "ws"; -import { decodeTokenOrNull } from "../auth"; +import { WebSocketServer, WebSocket } from "ws"; +import { decodeTokenOrNull, getPublicUserObject } from "../auth"; +import { query } from "../database"; import { gatewayErrors } from "../errors"; import { GatewayPayload } from "../types/gatewaypayload"; import { GatewayPayloadType } from "./gatewaypayloadtype"; @@ -9,8 +10,57 @@ import { GatewayPayloadType } from "./gatewaypayloadtype"; const GATEWAY_BATCH_INTERVAL = 25000 || process.env.GATEWAY_BATCH_INTERVAL; const GATEWAY_PING_INTERVAL = 20000 || process.env.GATEWAY_PING_INTERVAL; +// mapping between a broadcast id and a websocket client +const broadcastChannels = new Map>(); + +function clientSubscribe(ws: WebSocket, broadcastChannel: string) { + ws.state.broadcastChannels.add(broadcastChannel); + if (!broadcastChannels.get(broadcastChannel)) { + broadcastChannels.set(broadcastChannel, new Set()); + } + + broadcastChannels.get(broadcastChannel)?.add(ws); +} + +function clientUnsubscribe(ws: WebSocket, broadcastChannel: string) { + if (!ws.state) return; + + ws.state.broadcastChannels.delete(broadcastChannel); + + const set = broadcastChannels.get(broadcastChannel); + if (!set) return; + + set.delete(ws); + if (set.size < 1) { + broadcastChannels.delete(broadcastChannel); + } +} + +function clientUnsubscribeAll(ws: WebSocket) { + if (!ws.state) return; + + ws.state.broadcastChannels.forEach(e => { + const set = broadcastChannels.get(e); + if (!set) return; + + set.delete(ws); + if (set && set.size < 1) { + broadcastChannels.delete(e); + } + }); + + ws.state.broadcastChannels = new Set(); +} + +export function broadcast(channel: string, message: GatewayPayload) { + const members = broadcastChannels.get(channel); + if (!members) return; + + members.forEach(e => e.send(JSON.stringify(message))); +} + function closeWithError(ws: WebSocket, { code, message }: { code: number, message: string }) { - return ws.close(1000, `(${code}) ${message}`); + return ws.close(code, message); } function closeWithBadPayload(ws: WebSocket, hint: string) { @@ -57,23 +107,33 @@ function sendPayload(ws: WebSocket, payload: GatewayPayload) { export default function(server: Server) { const wss = new WebSocketServer({ server }); - setInterval(() => { + const batchInterval = setInterval(() => { wss.clients.forEach((e) => { const now = performance.now(); if (e.state && (now - e.state.lastAliveCheck) >= GATEWAY_PING_INTERVAL) { if (!e.state.ready) { return closeWithError(e, gatewayErrors.AUTHENTICATION_TIMEOUT); } + if (!e.state.alive) { + return closeWithError(e, gatewayErrors.NO_PING); + } } }); }, GATEWAY_BATCH_INTERVAL); + wss.on("close", () => { + console.warn("gateway: websocket server closed"); + console.warn("gateway: clearing batch interval due to websocket server close"); + clearInterval(batchInterval); + }); + wss.on("connection", (ws) => { ws.state = { user: undefined, alive: false, ready: false, - lastAliveCheck: performance.now() + lastAliveCheck: performance.now(), + broadcastChannels: new Set() }; sendPayload(ws, { @@ -83,6 +143,11 @@ export default function(server: Server) { } }); + ws.on("close", () => { + clientUnsubscribeAll(ws); + console.log(broadcastChannels); + }); + ws.on("message", async (rawData, isBinary) => { if (isBinary) { return closeWithBadPayload(ws, "Binary messages are not supported"); @@ -103,17 +168,30 @@ export default function(server: Server) { if (!user) { return closeWithError(ws, gatewayErrors.BAD_AUTH); } + // each user should have their own list of channels that they join + const channels = await query("SELECT id, name, owner_id FROM channels"); + + channels.rows.forEach(c => { + clientSubscribe(ws, `channel:${c.id}`); + }); + ws.state.user = user; ws.state.ready = true; - + sendPayload(ws, { t: GatewayPayloadType.Ready, d: { - user: ws.state.user, + user: getPublicUserObject(ws.state.user), + channels: channels.rows } }) break; } + case GatewayPayloadType.Ping: { + // TODO: also check session here and ensure packet is sent at the right time + ws.state.alive = true; + break; + } default: { return closeWithBadPayload(ws, "t: unknown type"); } diff --git a/src/types/gatewayclientstate.d.ts b/src/types/gatewayclientstate.d.ts index cf44b3a..72d10c8 100644 --- a/src/types/gatewayclientstate.d.ts +++ b/src/types/gatewayclientstate.d.ts @@ -3,4 +3,5 @@ interface GatewayClientState { ready: boolean, alive: boolean, lastAliveCheck: number, + broadcastChannels: Set }