waffle/src/gateway/index.ts

320 lines
11 KiB
TypeScript

import { Server } from "node:http";
import { performance } from "node:perf_hooks";
import { WebSocketServer, WebSocket } from "ws";
import { decodeTokenOrNull, getPublicUserObject } from "../auth";
import { query } from "../database";
import { gatewayErrors } from "../errors";
import { GatewayPayload } from "../types/gatewaypayload";
import { GatewayPayloadType, GatewayPresenceStatus } from "./gatewaypayloadtype";
import { GatewayPresenceEntry } from "./gatewaypresence";
// Period (ms) of the gateway's housekeeping timer, which checks each client's
// ping/authentication deadline and its flood counter. It is kept well below
// GATEWAY_PING_INTERVAL so a passed deadline is noticed within a few seconds
// instead of only once per (longer) batch period.
const GATEWAY_BATCH_INTERVAL = 5000;
// Ping interval (ms) advertised to clients in the Hello payload.
const GATEWAY_PING_INTERVAL = 40000;
// Messages a client may send between flood-counter resets before it is closed for flooding.
const MAX_CLIENT_MESSAGES_PER_BATCH = 6; // TODO: how well does this work for weak connections?
// Maximum number of authenticated gateway sessions a single user may hold at once.
const MAX_GATEWAY_SESSIONS_PER_USER = 5;
// mapping between a dispatch channel name (e.g. "*" or "channel:<id>") and the websocket clients subscribed to it
const dispatchChannels = new Map<string, Set<WebSocket>>();
// mapping between a user id and the websocket sessions it has
const sessionsByUserId = new Map<number, WebSocket[]>();
// Subscribes `ws` to `dispatchChannel`, recording the subscription both on the
// socket's own state and in the global channel -> clients map (creating the
// channel's member set on first use).
function clientSubscribe(ws: WebSocket, dispatchChannel: string) {
    ws.state.dispatchChannels.add(dispatchChannel);
    let members = dispatchChannels.get(dispatchChannel);
    if (members === undefined) {
        members = new Set();
        dispatchChannels.set(dispatchChannel, members);
    }
    members.add(ws);
}
// Removes `ws` from `dispatchChannel` on both sides (socket state and global
// map). A channel whose last member leaves is dropped from the map. Sockets
// without state are ignored.
function clientUnsubscribe(ws: WebSocket, dispatchChannel: string) {
    const state = ws.state;
    if (!state) {
        return;
    }
    state.dispatchChannels.delete(dispatchChannel);
    const members = dispatchChannels.get(dispatchChannel);
    if (members === undefined) {
        return;
    }
    members.delete(ws);
    if (members.size === 0) {
        dispatchChannels.delete(dispatchChannel);
    }
}
// Subscribes every client currently in the `target` dispatch channel to
// `dispatchChannel` as well. Does nothing when `target` has no subscribers.
export function dispatchChannelSubscribe(target: string, dispatchChannel: string) {
    const members = dispatchChannels.get(target);
    if (members === undefined) {
        return;
    }
    for (const client of members) {
        clientSubscribe(client, dispatchChannel);
    }
}
// Removes `ws` from every dispatch channel it is subscribed to, dropping
// channels that become empty, then resets the socket's own subscription set.
// Sockets without state are ignored.
function clientUnsubscribeAll(ws: WebSocket) {
    if (!ws.state) return;
    for (const channel of ws.state.dispatchChannels) {
        const members = dispatchChannels.get(channel);
        if (members === undefined) continue;
        members.delete(ws);
        if (members.size === 0) {
            dispatchChannels.delete(channel);
        }
    }
    ws.state.dispatchChannels = new Set();
}
// Sends `message` as JSON to every ready client subscribed to `channel`.
// When `message` is a function it is called once per recipient to build that
// client's payload. Unknown channels are a no-op.
export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSocket) => GatewayPayload)) {
    const members = dispatchChannels.get(channel);
    if (!members) return;
    for (const client of members) {
        if (!client.state.ready) continue;
        const payload = typeof message === "function" ? message(client) : message;
        client.send(JSON.stringify(payload));
    }
}
// Closes `ws` using a gateway error's close code, with its message as the close reason.
function closeWithError(ws: WebSocket, error: { code: number, message: string }) {
    const { code, message } = error;
    return ws.close(code, message);
}
// Closes `ws` with the BAD_PAYLOAD error, appending `hint` to its message so
// the client can tell which part of the payload was rejected.
function closeWithBadPayload(ws: WebSocket, hint: string) {
    const { code, message } = gatewayErrors.BAD_PAYLOAD;
    const reason = `${message}: ${hint}`;
    return ws.close(code, reason);
}
// Parses `payload` as JSON, returning null instead of throwing when it is malformed.
// Note that the valid JSON text "null" also yields null.
function parseJsonOrNull(payload: string): any {
    let parsed: any = null;
    try {
        parsed = JSON.parse(payload);
    } catch {
        // malformed JSON: leave `parsed` as null
    }
    return parsed;
}
// Checks that `payload` has exactly the shape of a GatewayPayload: its own
// enumerable keys are precisely `t` and `d`, and `t` is a number (`d` may hold
// any value, including null). On success the same object is returned typed as
// GatewayPayload; anything else (extra/missing keys, non-objects, null) yields null.
function ensureFormattedGatewayPayload(payload: any): GatewayPayload | null {
    if (!payload) {
        return null;
    }
    const keys = Object.keys(payload);
    const hasExactlyTAndD = keys.length === 2 && keys.includes("t") && keys.includes("d");
    if (!hasExactlyTAndD || typeof payload.t !== "number") {
        return null;
    }
    return payload;
}
// Serializes `payload` to JSON and sends it to a single client.
function sendPayload(ws: WebSocket, payload: GatewayPayload) {
    const serialized = JSON.stringify(payload);
    ws.send(serialized);
}
// Builds a presence entry for `user` with the given status. Only the user's
// id and username are copied into the entry; other user fields are left out.
function getPresenceEntryForUser(user: User, status: GatewayPresenceStatus): GatewayPresenceEntry {
    const { id, username } = user;
    return { user: { id, username }, status };
}
// Builds the presence list sent to a client in its Ready payload: one Online
// entry for every user that currently has at least one ready session with a user.
// In the future, each user will have their own list of channels that they can join and leave.
// In that case, we will send the presence entries to a certain user only for the channels they're in.
function getInitialPresenceEntries(): GatewayPresenceEntry[] {
    const entries: GatewayPresenceEntry[] = [];
    for (const sessions of sessionsByUserId.values()) {
        // A user's first session is not necessarily a ready one (sessions can be
        // registered before `ready` is set), so look at all of them.
        const user = sessions.find(s => s.state.ready && s.state.user)?.state.user;
        if (user) {
            entries.push(getPresenceEntryForUser(user, GatewayPresenceStatus.Online));
        }
    }
    return entries;
}
// Sockets that have already sent an Authenticate payload. Authentication awaits
// the token decoder and the database, so without this a second Authenticate sent
// while the first is still in flight would also pass the `ready` check and
// register the same socket twice.
const authenticationStartedSockets = new WeakSet<WebSocket>();

// Per-client housekeeping, run by the batch timer once per GATEWAY_BATCH_INTERVAL.
// When GATEWAY_PING_INTERVAL has elapsed since the client's last check, the client
// must be authenticated and must have pinged during that window, or it is closed;
// a client that passes starts a new window. The flood counter is reset on every run.
function runClientBatchCheck(ws: WebSocket, now: number) {
    if (!ws.state) return;
    if ((now - ws.state.lastAliveCheck) >= GATEWAY_PING_INTERVAL) {
        if (!ws.state.ready) {
            return closeWithError(ws, gatewayErrors.AUTHENTICATION_TIMEOUT);
        }
        if (!ws.state.alive) {
            return closeWithError(ws, gatewayErrors.NO_PING);
        }
        // Start a new ping window. Without resetting `alive` it would stay true
        // after the first ping, and NO_PING could never be raised again.
        ws.state.alive = false;
        ws.state.lastAliveCheck = now;
    }
    // MAX_CLIENT_MESSAGES_PER_BATCH is a per-batch budget, so refill it every batch.
    ws.state.messagesSinceLastCheck = 0;
}

// Close handler: drops the socket's dispatch subscriptions and its entry in
// sessionsByUserId, and publishes an Offline presence update to "*" once the
// user's last session is gone.
function cleanUpClosedClient(ws: WebSocket) {
    clientUnsubscribeAll(ws);
    ws.state.ready = false;
    const user = ws.state.user;
    if (!user) {
        // never finished authenticating, so it was never registered as a session
        return;
    }
    const sessions = sessionsByUserId.get(user.id);
    if (!sessions) return;
    const index = sessions.indexOf(ws);
    // splice(-1, 1) would drop an unrelated session (the last one), so only
    // remove a socket that is actually registered.
    if (index === -1) return;
    sessions.splice(index, 1);
    if (sessions.length < 1) {
        sessionsByUserId.delete(user.id);
        // user no longer has any sessions, update presence
        dispatch("*", {
            t: GatewayPayloadType.PresenceUpdate,
            d: [getPresenceEntryForUser(user, GatewayPresenceStatus.Offline)]
        });
    }
}

// Returns true when the user with `userId` may open another gateway session.
function hasFreeSessionSlot(userId: number): boolean {
    const sessions = sessionsByUserId.get(userId);
    return !sessions || sessions.length < MAX_GATEWAY_SESSIONS_PER_USER;
}

// Handles an Authenticate payload whose `d` is `token`. It validates the token,
// registers the socket as one of the user's sessions, subscribes it to "*" and to
// every channel's dispatch channel, and replies with Ready.
// After each await the socket may already be closing or closed, and its close
// handler may already have run. Registering it at that point would leave a dead
// session in sessionsByUserId and dead subscriptions behind, so such sockets are
// simply dropped.
async function authenticateClient(ws: WebSocket, token: unknown) {
    if (ws.state.ready || authenticationStartedSockets.has(ws)) {
        return closeWithError(ws, gatewayErrors.ALREADY_AUTHENTICATED);
    }
    if (typeof token !== "string") {
        return closeWithBadPayload(ws, "d: expected string");
    }
    authenticationStartedSockets.add(ws);
    const user = await decodeTokenOrNull(token);
    if (ws.readyState !== WebSocket.OPEN) return;
    if (!user) {
        return closeWithError(ws, gatewayErrors.BAD_AUTH);
    }
    // Checked before the query to avoid database work for a client that will be rejected anyway.
    if (!hasFreeSessionSlot(user.id)) {
        return closeWithError(ws, gatewayErrors.TOO_MANY_SESSIONS);
    }
    // TODO: each user should have their own list of channels that they join
    const channels = await query("SELECT id, name, owner_id FROM channels ORDER BY id ASC");
    if (ws.readyState !== WebSocket.OPEN) return;
    if (!channels) {
        return closeWithError(ws, gatewayErrors.GOT_NO_DATABASE_DATA);
    }
    // Checked again: other sessions of this user may have registered during the query.
    if (!hasFreeSessionSlot(user.id)) {
        return closeWithError(ws, gatewayErrors.TOO_MANY_SESSIONS);
    }
    // Nothing below awaits, so registration, subscription and `ready` happen
    // without any other gateway event interleaving.
    let sessions = sessionsByUserId.get(user.id);
    if (!sessions) {
        sessions = [];
        sessionsByUserId.set(user.id, sessions);
    }
    sessions.push(ws);
    ws.state.user = user;
    clientSubscribe(ws, "*");
    channels.rows.forEach(c => {
        clientSubscribe(ws, `channel:${c.id}`);
    });
    // First session: notify others that we are online. This socket is not ready
    // yet, so dispatch skips it; it receives the presence list in Ready instead.
    if (sessions.length === 1) {
        dispatch("*", {
            t: GatewayPayloadType.PresenceUpdate,
            d: [getPresenceEntryForUser(user, GatewayPresenceStatus.Online)]
        });
    }
    ws.state.ready = true;
    sendPayload(ws, {
        t: GatewayPayloadType.Ready,
        d: {
            user: getPublicUserObject(user),
            channels: channels.rows,
            presence: getInitialPresenceEntries()
        }
    });
}

/**
 * Attaches the gateway WebSocket server to `server`.
 *
 * Each connection receives a Hello payload carrying the ping interval. It is
 * then expected to send Authenticate (a token in `d`) followed by periodic Ping
 * payloads (`d: 0`). Binary, oversized, malformed, unknown or too frequent
 * messages close the connection with a gateway error. A housekeeping timer runs
 * every GATEWAY_BATCH_INTERVAL ms until the WebSocket server closes.
 */
export default function(server: Server) {
    const wss = new WebSocketServer({ server });
    const batchInterval = setInterval(() => {
        const now = performance.now();
        wss.clients.forEach((client) => {
            runClientBatchCheck(client, now);
        });
    }, GATEWAY_BATCH_INTERVAL);
    wss.on("close", () => {
        console.warn("gateway: websocket server closed");
        console.warn("gateway: clearing batch interval due to websocket server close");
        clearInterval(batchInterval);
    });
    wss.on("connection", (ws) => {
        ws.state = {
            user: undefined,
            alive: false,
            ready: false,
            lastAliveCheck: performance.now(),
            dispatchChannels: new Set(),
            messagesSinceLastCheck: 0
        };
        sendPayload(ws, {
            t: GatewayPayloadType.Hello,
            d: {
                pingInterval: GATEWAY_PING_INTERVAL
            }
        });
        ws.on("close", () => {
            cleanUpClosedClient(ws);
        });
        ws.on("message", (rawData, isBinary) => {
            if (isBinary) {
                return closeWithBadPayload(ws, "Binary messages are not supported");
            }
            ws.state.messagesSinceLastCheck++;
            if (ws.state.messagesSinceLastCheck > MAX_CLIENT_MESSAGES_PER_BATCH) {
                return closeWithError(ws, gatewayErrors.FLOODING);
            }
            const stringData = rawData.toString();
            if (stringData.length > 2048) {
                return closeWithError(ws, gatewayErrors.PAYLOAD_TOO_LARGE);
            }
            const payload = ensureFormattedGatewayPayload(parseJsonOrNull(stringData));
            if (!payload) {
                return closeWithBadPayload(ws, "Invalid JSON or message does not match schema");
            }
            switch (payload.t) {
                case GatewayPayloadType.Authenticate: {
                    // The emitter ignores a listener's returned promise, so a rejection
                    // (e.g. the token decoder or the query throwing) must be handled
                    // here. Otherwise it is an unhandled rejection, which by default
                    // terminates the Node process (v15+).
                    authenticateClient(ws, payload.d).catch((error: unknown) => {
                        console.error("gateway: error while authenticating client", error);
                        ws.close(1011); // RFC 6455: 1011 = internal server error
                    });
                    break;
                }
                case GatewayPayloadType.Ping: {
                    if (payload.d !== 0) {
                        return closeWithBadPayload(ws, "d: expected numeric '0'");
                    }
                    if (!ws.state.ready) {
                        return closeWithError(ws, gatewayErrors.NOT_AUTHENTICATED);
                    }
                    // TODO: also check session here
                    ws.state.alive = true;
                    break;
                }
                default: {
                    return closeWithBadPayload(ws, "t: unknown type");
                }
            }
        });
    });
}