2022-04-10 01:22:07 +03:00
|
|
|
import { Server } from "node:http";
|
|
|
|
import { performance } from "node:perf_hooks";
|
2022-04-10 21:10:19 +03:00
|
|
|
import { WebSocketServer, WebSocket } from "ws";
|
|
|
|
import { decodeTokenOrNull, getPublicUserObject } from "../auth";
|
|
|
|
import { query } from "../database";
|
2022-04-10 01:22:07 +03:00
|
|
|
import { gatewayErrors } from "../errors";
|
|
|
|
import { GatewayPayload } from "../types/gatewaypayload";
|
|
|
|
import { GatewayPayloadType } from "./gatewaypayloadtype";
|
|
|
|
|
2022-04-14 17:02:51 +03:00
|
|
|
/** Period, in ms, of the server-side check loop (heartbeat and flood-counter reset). */
const GATEWAY_BATCH_INTERVAL = 50 * 1000;

/** Ping interval, in ms, advertised to clients in the Hello payload and used by the heartbeat check. */
const GATEWAY_PING_INTERVAL = 40 * 1000;

/** Messages a client may send between two flood-counter resets before it is closed for flooding. */
const MAX_CLIENT_MESSAGES_PER_BATCH = 6; // TODO: how well does this work for weak connections?

/** Gateway sessions a single user may have registered at the same time. */
const MAX_GATEWAY_SESSIONS_PER_USER = 5;
|
2022-04-10 01:22:07 +03:00
|
|
|
|
2022-04-10 21:28:36 +03:00
|
|
|
/**
 * Dispatch channel name -> sockets currently subscribed to it.
 * Channels whose subscriber set becomes empty are removed from this map.
 */
const dispatchChannels: Map<string, Set<WebSocket>> = new Map();

/**
 * User id -> gateway sockets registered as sessions of that user.
 * A user's entry is removed once its last session closes.
 */
const sessionsByUserId: Map<number, Set<WebSocket>> = new Map();
|
|
|
|
|
2022-04-10 21:28:36 +03:00
|
|
|
/**
 * Subscribes `ws` to `dispatchChannel`, recording the subscription both in the
 * socket's own state and in the global channel -> sockets index.
 */
function clientSubscribe(ws: WebSocket, dispatchChannel: string) {
    ws.state.dispatchChannels.add(dispatchChannel);

    let subscribers = dispatchChannels.get(dispatchChannel);
    if (subscribers === undefined) {
        // First subscriber of this channel: create its set lazily.
        subscribers = new Set();
        dispatchChannels.set(dispatchChannel, subscribers);
    }
    subscribers.add(ws);
}
|
|
|
|
|
2022-04-10 21:28:36 +03:00
|
|
|
/**
 * Removes the subscription of `ws` to `dispatchChannel` from both the socket's
 * state and the global index, deleting the channel entry once it has no
 * subscribers left. Sockets without state are ignored.
 */
function clientUnsubscribe(ws: WebSocket, dispatchChannel: string) {
    if (ws.state) {
        ws.state.dispatchChannels.delete(dispatchChannel);

        const subscribers = dispatchChannels.get(dispatchChannel);
        if (subscribers !== undefined) {
            subscribers.delete(ws);
            if (subscribers.size === 0) {
                dispatchChannels.delete(dispatchChannel);
            }
        }
    }
}
|
|
|
|
|
2022-04-12 00:19:29 +03:00
|
|
|
/**
 * Subscribes every socket currently subscribed to `target` to `dispatchChannel`
 * as well. Does nothing when `target` has no subscribers.
 */
export function dispatchChannelSubscribe(target: string, dispatchChannel: string) {
    dispatchChannels.get(target)?.forEach((client) => clientSubscribe(client, dispatchChannel));
}
|
|
|
|
|
2022-04-10 21:10:19 +03:00
|
|
|
/**
 * Removes `ws` from every dispatch channel listed in its state, deleting channel
 * entries that end up with no subscribers, then resets the socket's own
 * subscription list. Sockets without state are ignored.
 */
function clientUnsubscribeAll(ws: WebSocket) {
    const state = ws.state;
    if (!state) return;

    state.dispatchChannels.forEach((channelName) => {
        const subscribers = dispatchChannels.get(channelName);
        if (subscribers === undefined) return;

        subscribers.delete(ws);
        if (subscribers.size === 0) {
            dispatchChannels.delete(channelName);
        }
    });

    state.dispatchChannels = new Set();
}
|
|
|
|
|
2022-04-10 21:28:36 +03:00
|
|
|
/**
 * Sends `message` to every socket subscribed to `channel`.
 *
 * The message is serialized once and the same JSON string is sent to every
 * subscriber. Does nothing when the channel has no subscribers.
 */
export function dispatch(channel: string, message: GatewayPayload) {
    const members = dispatchChannels.get(channel);
    if (!members || members.size === 0) return;

    // Serialization does not depend on the recipient, so do it once per dispatch
    // rather than once per subscriber.
    const serialized = JSON.stringify(message);
    members.forEach(e => {
        e.send(serialized);
    });
}
|
|
|
|
|
2022-04-10 01:22:07 +03:00
|
|
|
/**
 * Closes `ws`, using the descriptor's `code` as the close code and its
 * `message` as the close reason.
 */
function closeWithError(ws: WebSocket, error: { code: number, message: string }) {
    const { code, message } = error;
    return ws.close(code, message);
}
|
|
|
|
|
|
|
|
/**
 * Closes `ws` with the BAD_PAYLOAD gateway error. The close reason is the
 * error's message followed by `": "` and the given `hint`.
 */
function closeWithBadPayload(ws: WebSocket, hint: string) {
    const { code, message } = gatewayErrors.BAD_PAYLOAD;
    const reason = `${message}: ${hint}`;
    return ws.close(code, reason);
}
|
|
|
|
|
|
|
|
/**
 * Parses `payload` as JSON.
 *
 * The result is typed `unknown` because it comes straight from the client and
 * must be validated before use.
 *
 * @returns the parsed value, or `null` if `payload` is not valid JSON. Note that
 *          the JSON text `null` also yields `null`.
 */
function parseJsonOrNull(payload: string): unknown {
    try {
        return JSON.parse(payload);
    } catch {
        // Malformed input is an expected client error, reported to the caller as null.
        return null;
    }
}
|
|
|
|
|
|
|
|
/**
 * Checks that `payload` has exactly the shape of a GatewayPayload: an object
 * whose own enumerable keys are exactly `t` (a number) and `d` (any value).
 *
 * @returns the same `payload` object typed as GatewayPayload, or `null` when it
 *          is falsy, is missing `t` or `d`, has a non-numeric `t`, or has any
 *          additional key.
 */
function ensureFormattedGatewayPayload(payload: any): GatewayPayload | null {
    if (!payload) return null;

    const entries = Object.entries(payload);
    // Keys are unique, so "exactly two entries, each a numeric `t` or a `d`"
    // means both fields are present and nothing else is.
    const isWellFormed = entries.length === 2
        && entries.every(([key, value]) => key === "d" || (key === "t" && typeof value === "number"));

    return isWellFormed ? (payload as GatewayPayload) : null;
}
|
|
|
|
|
|
|
|
/** Serializes `payload` to JSON and sends the resulting string over `ws`. */
function sendPayload(ws: WebSocket, payload: GatewayPayload) {
    const text = JSON.stringify(payload);
    ws.send(text);
}
|
|
|
|
|
|
|
|
/** RFC 6455 close code 1011 ("Internal Error"): used when handling a message fails unexpectedly. */
const INTERNAL_ERROR_CLOSE_CODE = 1011;

/**
 * Heartbeat / flood-counter check for a single client, run on every batch tick.
 *
 * A client is only checked once GATEWAY_PING_INTERVAL ms have elapsed since its
 * previous check (or since it connected). At that point it must have
 * authenticated and must have pinged since the previous check, otherwise it is
 * closed. A client that passes has its message counter reset and starts a new
 * window in which it has to ping again.
 */
function checkClientHeartbeat(ws: WebSocket, now: number) {
    if (!ws.state || (now - ws.state.lastAliveCheck) < GATEWAY_PING_INTERVAL) {
        return;
    }

    if (!ws.state.ready) {
        return closeWithError(ws, gatewayErrors.AUTHENTICATION_TIMEOUT);
    }

    if (!ws.state.alive) {
        return closeWithError(ws, gatewayErrors.NO_PING);
    }

    ws.state.messagesSinceLastCheck = 0;

    // Clear the liveness flag and restart the window. Without this, a single ping
    // would satisfy every future check and dead clients would never be dropped.
    ws.state.alive = false;
    ws.state.lastAliveCheck = now;
}

/**
 * Handles an Authenticate payload.
 *
 * Validates the token, registers the socket as one of the user's sessions
 * (bounded by MAX_GATEWAY_SESSIONS_PER_USER), subscribes it to "*" and to every
 * channel returned by the channels query, marks it ready and replies with Ready.
 * Rejections close the socket with the corresponding gateway error.
 */
async function handleAuthenticate(ws: WebSocket, payload: GatewayPayload) {
    // `user` is set before `ready` (see below), so checking both also rejects an
    // Authenticate that arrives while a previous one is still in flight.
    if (ws.state.ready || ws.state.user) {
        return closeWithError(ws, gatewayErrors.ALREADY_AUTHENTICATED);
    }

    const token = payload.d;
    if (typeof token !== "string") {
        return closeWithBadPayload(ws, "d: expected string");
    }

    const user = await decodeTokenOrNull(token);
    if (!user) {
        return closeWithError(ws, gatewayErrors.BAD_AUTH);
    }

    // While the token was being decoded, the socket may have closed, or a
    // concurrent Authenticate may have claimed it.
    if (ws.readyState !== WebSocket.OPEN) return;
    if (ws.state.ready || ws.state.user) {
        return closeWithError(ws, gatewayErrors.ALREADY_AUTHENTICATED);
    }

    let sessions = sessionsByUserId.get(user.id);
    if (sessions) {
        if ((sessions.size + 1) > MAX_GATEWAY_SESSIONS_PER_USER) {
            return closeWithError(ws, gatewayErrors.TOO_MANY_SESSIONS);
        }
    } else {
        sessions = new Set();
        sessionsByUserId.set(user.id, sessions);
    }

    // Record the user in the same synchronous step that registers the session.
    // The "close" handler looks the session up through ws.state.user, so it can
    // always remove it, even if the socket closes during the query below.
    ws.state.user = user;
    sessions.add(ws);

    // TODO: each user should have their own list of channels that they join
    const channels = await query("SELECT id, name, owner_id FROM channels ORDER BY id ASC");

    // If the socket closed during the query, its "close" handler may already have
    // run clientUnsubscribeAll(); subscribing now would leave it in dispatchChannels.
    if (ws.readyState !== WebSocket.OPEN) return;

    clientSubscribe(ws, "*");
    channels.rows.forEach(c => {
        clientSubscribe(ws, `channel:${c.id}`);
    });

    ws.state.ready = true;

    sendPayload(ws, {
        t: GatewayPayloadType.Ready,
        d: {
            user: getPublicUserObject(user),
            channels: channels.rows
        }
    });
}

/**
 * Attaches the gateway WebSocket server to `server`.
 *
 * Starts the periodic heartbeat / flood-counter check (stopped when the
 * WebSocket server closes) and wires the per-connection handlers. On connect,
 * the server sends Hello; on each message it rate-limits and validates, then
 * handles the message; on close it cleans up subscriptions and sessions.
 */
export default function(server: Server) {
    const wss = new WebSocketServer({ server });

    const batchInterval = setInterval(() => {
        wss.clients.forEach((client) => {
            checkClientHeartbeat(client, performance.now());
        });
    }, GATEWAY_BATCH_INTERVAL);

    wss.on("close", () => {
        console.warn("gateway: websocket server closed");
        console.warn("gateway: clearing batch interval due to websocket server close");
        clearInterval(batchInterval);
    });

    wss.on("connection", (ws) => {
        ws.state = {
            user: undefined,
            alive: false,
            ready: false,
            lastAliveCheck: performance.now(),
            dispatchChannels: new Set(),
            messagesSinceLastCheck: 0
        };

        sendPayload(ws, {
            t: GatewayPayloadType.Hello,
            d: {
                pingInterval: GATEWAY_PING_INTERVAL
            }
        });

        ws.on("close", () => {
            clientUnsubscribeAll(ws);

            if (ws.state.user && ws.state.user.id) {
                const sessions = sessionsByUserId.get(ws.state.user.id);
                if (sessions) {
                    sessions.delete(ws);
                    if (sessions.size < 1) {
                        sessionsByUserId.delete(ws.state.user.id);
                    }
                }
            }
        });

        ws.on("message", async (rawData, isBinary) => {
            // The EventEmitter ignores the promise returned by an async listener.
            // Anything thrown here (e.g. a failed database query) would otherwise
            // become an unhandled rejection, which terminates the process under
            // Node's default settings (Node 15+).
            try {
                if (isBinary) {
                    return closeWithBadPayload(ws, "Binary messages are not supported");
                }

                ws.state.messagesSinceLastCheck++;
                if (ws.state.messagesSinceLastCheck > MAX_CLIENT_MESSAGES_PER_BATCH) {
                    return closeWithError(ws, gatewayErrors.FLOODING);
                }

                const stringData = rawData.toString();
                if (stringData.length > 2048) {
                    return closeWithError(ws, gatewayErrors.PAYLOAD_TOO_LARGE);
                }

                const payload = ensureFormattedGatewayPayload(parseJsonOrNull(stringData));
                if (!payload) {
                    return closeWithBadPayload(ws, "Invalid JSON or message does not match schema");
                }

                switch (payload.t) {
                    case GatewayPayloadType.Authenticate: {
                        // Awaited so that failures are caught below.
                        return await handleAuthenticate(ws, payload);
                    }
                    case GatewayPayloadType.Ping: {
                        if (payload.d !== 0) {
                            return closeWithBadPayload(ws, "d: expected numeric '0'");
                        }

                        // TODO: also check session here
                        ws.state.alive = true;
                        break;
                    }
                    default: {
                        return closeWithBadPayload(ws, "t: unknown type");
                    }
                }
            } catch (err) {
                console.error("gateway: error while handling message", err);
                if (ws.readyState === WebSocket.OPEN) {
                    ws.close(INTERNAL_ERROR_CLOSE_CODE, "Internal server error");
                }
            }
        });
    });
};
|