greatly refactor gateway

This commit is contained in:
hippoz 2023-08-08 15:36:51 +03:00
parent 689787247e
commit 52d253f2cf
Signed by: hippoz
GPG key ID: 56C4E02A85F2FBED
3 changed files with 313 additions and 339 deletions

View file

@ -1,6 +1,6 @@
import { Server } from "node:http"; import { Server } from "node:http";
import { performance } from "node:perf_hooks"; import { performance } from "node:perf_hooks";
import { WebSocketServer, WebSocket } from "ws"; import WebSocket, { WebSocketServer } from "ws";
import { decodeTokenOrNull, getPublicUserObject } from "../auth"; import { decodeTokenOrNull, getPublicUserObject } from "../auth";
import { query } from "../database"; import { query } from "../database";
import { gatewayErrors } from "../errors"; import { gatewayErrors } from "../errors";
@ -15,16 +15,21 @@ const GATEWAY_PING_INTERVAL = 40000;
const MAX_CLIENT_MESSAGES_PER_BATCH = 30; // TODO: how well does this work for weak connections? const MAX_CLIENT_MESSAGES_PER_BATCH = 30; // TODO: how well does this work for weak connections?
const MAX_GATEWAY_SESSIONS_PER_USER = 5; const MAX_GATEWAY_SESSIONS_PER_USER = 5;
// mapping between a dispatch id and a websocket client // mapping between a dispatch id and a websocket client
const dispatchChannels = new Map<string, Set<WebSocket>>(); const dispatchChannels = new Map<string, Set<GatewayClient>>();
// mapping between a user id and the websocket sessions it has // mapping between a user id and the websocket sessions it has
const sessionsByUserId = new Map<number, WebSocket[]>(); const sessionsByUserId = new Map<number, GatewayClient[]>();
// mapping between a dispatch id and a temporary handler // mapping between a dispatch id and a temporary handler
const dispatchTemporary = new Map<string, Set<(payload: GatewayPayload) => void>>(); const dispatchTemporary = new Map<string, Set<(payload: GatewayPayload) => void>>();
export function handle(channels: string[], handler: (payload: GatewayPayload) => void): (() => any) { // all clients
const gatewayClients = new Set<GatewayClient>();
function handle(channels: string[], handler: (payload: GatewayPayload) => void): (() => any) {
channels.forEach(c => { channels.forEach(c => {
if (!dispatchTemporary.get(c)) { if (!dispatchTemporary.get(c)) {
dispatchTemporary.set(c, new Set()); dispatchTemporary.set(c, new Set());
@ -64,65 +69,17 @@ export function waitForEvent(channels: string[], timeout: number) {
}); });
} }
function clientSubscribe(ws: WebSocket, dispatchChannel: string) { export function dispatch(channel: string, message: GatewayPayload | ((ws: GatewayClient | null) => GatewayPayload)) {
ws.state.dispatchChannels.add(dispatchChannel);
if (!dispatchChannels.get(dispatchChannel)) {
dispatchChannels.set(dispatchChannel, new Set());
}
dispatchChannels.get(dispatchChannel)?.add(ws);
}
function clientUnsubscribe(ws: WebSocket, dispatchChannel: string) {
if (!ws.state) return;
ws.state.dispatchChannels.delete(dispatchChannel);
const set = dispatchChannels.get(dispatchChannel);
if (!set) return;
set.delete(ws);
if (set.size < 1) {
dispatchChannels.delete(dispatchChannel);
}
}
export function dispatchChannelSubscribe(target: string, dispatchChannel: string) {
const set = dispatchChannels.get(target);
if (!set) return;
set.forEach(c => {
clientSubscribe(c, dispatchChannel);
});
}
function clientUnsubscribeAll(ws: WebSocket) {
if (!ws.state) return;
ws.state.dispatchChannels.forEach(e => {
const set = dispatchChannels.get(e);
if (!set) return;
set.delete(ws);
if (set && set.size < 1) {
dispatchChannels.delete(e);
}
});
ws.state.dispatchChannels = new Set();
}
export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSocket | null) => GatewayPayload)) {
const members = dispatchChannels.get(channel); const members = dispatchChannels.get(channel);
if (!members) return; if (!members) return;
members.forEach(e => { members.forEach(e => {
if (e.state.ready) { if (e.ready) {
let data = message; let data = message;
if (typeof message === "function") { if (typeof message === "function") {
data = message(e); data = message(e);
} }
e.send(JSON.stringify(data)); e.send(data);
} }
}); });
@ -139,14 +96,34 @@ export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSoc
} }
} }
function closeWithError(ws: WebSocket, { code, message }: { code: number, message: string }) { export function dispatchChannelSubscribe(target: string, dispatchChannel: string) {
return ws.close(code, message); const set = dispatchChannels.get(target);
set?.forEach(c => c.subscribe(dispatchChannel));
} }
function closeWithBadPayload(ws: WebSocket, hint: string) {
return ws.close(gatewayErrors.BAD_PAYLOAD.code, `${gatewayErrors.BAD_PAYLOAD.message}: ${hint}`);
}
function getInitialPresenceEntries(): GatewayPresenceEntry[] {
// The initial presence entries are sent right when the user connects.
// In the future, each user will have their own list of channels that they can join and leave.
// In that case, we will send the presence entries to a certain user only for the channels they're in.
const entries: GatewayPresenceEntry[] = [];
sessionsByUserId.forEach((clients: GatewayClient[], userId: number) => {
if (clients.length < 1)
return;
const firstClient = clients[0];
if (firstClient.ready && firstClient.user) {
const entry = firstClient.getPresenceEntry(GatewayPresenceStatus.Online);
if (entry) {
entries.push(entry);
}
}
});
return entries;
}
function parseJsonOrNull(payload: string): any { function parseJsonOrNull(payload: string): any {
try { try {
return JSON.parse(payload); return JSON.parse(payload);
@ -155,8 +132,6 @@ function parseJsonOrNull(payload: string): any {
} }
} }
// The function below ensures `payload` is of the GatewayPayload
// interface payload. If it does not match, null is returned.
function ensureFormattedGatewayPayload(payload: any): GatewayPayload | null { function ensureFormattedGatewayPayload(payload: any): GatewayPayload | null {
if (!payload) { if (!payload) {
return null; return null;
@ -183,131 +158,147 @@ function ensureFormattedGatewayPayload(payload: any): GatewayPayload | null {
return asPayload; return asPayload;
} }
function sendPayload(ws: WebSocket, payload: GatewayPayload) {
ws.send(JSON.stringify(payload));
}
function getPresenceEntryForConnection(ws: WebSocket, status: GatewayPresenceStatus): GatewayPresenceEntry | null { class GatewayClient {
if (!ws.state || !ws.state.user) { ws: WebSocket;
return null; user?: User;
ready: boolean;
alive: boolean;
lastAliveCheck: number;
clientDispatchChannels: Set<string>;
messagesSinceLastCheck: number;
bridgesTo?: string;
privacy?: string;
terms?: string;
constructor(ws: WebSocket) {
this.ws = ws;
this.user = undefined;
this.ready = false;
this.alive = false;
this.lastAliveCheck = performance.now();
this.clientDispatchChannels = new Set();
this.messagesSinceLastCheck = 0;
this.bridgesTo = undefined;
this.privacy = undefined;
this.terms = undefined;
gatewayClients.add(this);
this.ws.on("close", this.handleClose.bind(this));
this.ws.on("message", this.handleMessage.bind(this));
} }
const entry: GatewayPresenceEntry = { greet() {
user: { this.send({
id: ws.state.user.id,
username: ws.state.user.username,
avatar: ws.state.user.avatar
},
status
};
if (typeof ws.state.bridgesTo === "string") {
entry.bridgesTo = ws.state.bridgesTo;
}
if (typeof ws.state.privacy === "string") {
entry.privacy = ws.state.privacy;
}
if (typeof ws.state.terms === "string") {
entry.terms = ws.state.terms;
}
return entry;
}
// The initial presence entries are sent right when the user connects.
// In the future, each user will have their own list of channels that they can join and leave.
// In that case, we will send the presence entries to a certain user only for the channels they're in.
function getInitialPresenceEntries(): GatewayPresenceEntry[] {
const entries: GatewayPresenceEntry[] = [];
sessionsByUserId.forEach((wsList: WebSocket[], userId: number) => {
if (wsList.length < 1)
return;
const firstWs = wsList[0];
if (firstWs.state.ready && firstWs.state.user) {
const entry = getPresenceEntryForConnection(firstWs, GatewayPresenceStatus.Online);
if (entry) {
entries.push(entry);
}
}
});
return entries;
}
export default function(server: Server) {
const wss = new WebSocketServer({ server });
const batchInterval = setInterval(() => {
wss.clients.forEach((e) => {
const now = performance.now();
if (e.state && (now - e.state.lastAliveCheck) >= GATEWAY_PING_INTERVAL) {
if (!e.state.ready) {
return closeWithError(e, gatewayErrors.AUTHENTICATION_TIMEOUT);
}
if (!e.state.alive) {
return closeWithError(e, gatewayErrors.NO_PING);
}
e.state.messagesSinceLastCheck = 0;
}
});
}, GATEWAY_BATCH_INTERVAL);
wss.on("close", () => {
console.warn("gateway: websocket server closed");
console.warn("gateway: clearing batch interval due to websocket server close");
clearInterval(batchInterval);
});
wss.on("connection", (ws) => {
ws.state = {
user: undefined,
alive: false,
ready: false,
lastAliveCheck: performance.now(),
dispatchChannels: new Set(),
messagesSinceLastCheck: 0,
bridgesTo: undefined
};
sendPayload(ws, {
t: GatewayPayloadType.Hello, t: GatewayPayloadType.Hello,
d: { d: {
pingInterval: GATEWAY_PING_INTERVAL pingInterval: GATEWAY_PING_INTERVAL
} }
}); });
}
ws.on("close", () => { subscribe(channel: string) {
clientUnsubscribeAll(ws); this.clientDispatchChannels.add(channel);
ws.state.ready = false; if (!dispatchChannels.get(channel)) {
if (ws.state.user && ws.state.user.id) { dispatchChannels.set(channel, new Set());
const sessions = sessionsByUserId.get(ws.state.user.id); }
dispatchChannels.get(channel)?.add(this);
}
unsubscribeAll() {
this.clientDispatchChannels.forEach((channel) => {
const set = dispatchChannels.get(channel);
if (!set) return;
set.delete(this);
if (set && set.size < 1) {
dispatchChannels.delete(channel);
}
});
this.clientDispatchChannels.clear();
}
send(payload: object) {
this.ws.send(JSON.stringify(payload));
}
closeWithError({ code, message }: { code: number, message: string }) {
this.ws.close(code, message);
}
closeWithBadPayload(hint: string) {
this.ws.close(gatewayErrors.BAD_PAYLOAD.code, `${gatewayErrors.BAD_PAYLOAD.message}: ${hint}`);
}
getPresenceEntry(status: GatewayPresenceStatus): GatewayPresenceEntry | null {
if (!this.user || !this.ready) {
return null;
}
return {
user: {
id: this.user.id,
username: this.user.username,
avatar: this.user.avatar
},
status,
bridgesTo: this.bridgesTo,
privacy: this.privacy,
terms: this.terms,
};
}
handleClose() {
gatewayClients.delete(this);
this.unsubscribeAll();
this.ready = false;
if (this.user) {
const sessions = sessionsByUserId.get(this.user.id);
if (sessions) { if (sessions) {
const index = sessions.indexOf(ws); const index = sessions.indexOf(this);
sessions.splice(index, 1); sessions.splice(index, 1);
if (sessions.length < 1) { if (!sessions.length) {
sessionsByUserId.delete(ws.state.user.id); sessionsByUserId.delete(this.user.id);
// user no longer has any sessions, update presence // user no longer has any sessions, update presence
dispatch("*", { dispatch("*", {
t: GatewayPayloadType.PresenceUpdate, t: GatewayPayloadType.PresenceUpdate,
d: [getPresenceEntryForConnection(ws, GatewayPresenceStatus.Offline)] d: [
this.getPresenceEntry(GatewayPresenceStatus.Offline)
]
}); });
} }
} }
} }
}); }
ws.on("message", async (rawData: Buffer, isBinary) => { batchTick() {
ws.state.messagesSinceLastCheck++; const now = performance.now();
if (ws.state.messagesSinceLastCheck > MAX_CLIENT_MESSAGES_PER_BATCH) { if ((now - this.lastAliveCheck) >= GATEWAY_PING_INTERVAL) {
return closeWithError(ws, gatewayErrors.FLOODING); if (!this.ready) {
return this.closeWithError(gatewayErrors.AUTHENTICATION_TIMEOUT);
} }
if (!this.alive) {
return this.closeWithError(gatewayErrors.NO_PING);
}
this.messagesSinceLastCheck = 0;
}
}
async handleMessage(rawData: Buffer, isBinary: boolean) {
this.messagesSinceLastCheck++;
if (rawData.byteLength >= maxGatewayPayloadByteLength) { if (rawData.byteLength >= maxGatewayPayloadByteLength) {
return closeWithError(ws, gatewayErrors.PAYLOAD_TOO_LARGE); return this.closeWithError(gatewayErrors.PAYLOAD_TOO_LARGE);
} }
if (this.messagesSinceLastCheck > MAX_CLIENT_MESSAGES_PER_BATCH) {
return this.closeWithError(gatewayErrors.FLOODING);
}
let stringData: string; let stringData: string;
let binaryStream: Buffer | null = null; let binaryStream: Buffer | null = null;
if (isBinary) { if (isBinary) {
@ -326,7 +317,7 @@ export default function(server: Server) {
} }
} }
if (!jsonSlice) { if (!jsonSlice) {
return closeWithBadPayload(ws, "Did not find newline to delimit JSON from binary stream. JSON payload may be too large, or newline may be missing."); return this.closeWithBadPayload("Did not find newline to delimit JSON from binary stream. JSON payload may be too large, or newline may be missing.");
} }
binaryStream = rawData.subarray(jsonOffset, rawData.byteLength); binaryStream = rawData.subarray(jsonOffset, rawData.byteLength);
@ -336,50 +327,38 @@ export default function(server: Server) {
} }
if (stringData.length > maxGatewayJsonStringLength) { if (stringData.length > maxGatewayJsonStringLength) {
return closeWithError(ws, gatewayErrors.PAYLOAD_TOO_LARGE); return this.closeWithError(gatewayErrors.PAYLOAD_TOO_LARGE);
} }
const payload = ensureFormattedGatewayPayload(parseJsonOrNull(stringData)); const payload = ensureFormattedGatewayPayload(parseJsonOrNull(stringData));
if (!payload) { if (!payload) {
return closeWithBadPayload(ws, "Invalid JSON or message does not match schema"); return this.closeWithBadPayload("Invalid JSON or message does not match schema");
} }
switch (payload.t) { switch (payload.t) {
case GatewayPayloadType.Authenticate: { case GatewayPayloadType.Authenticate: {
if (ws.state.ready) { if (this.ready) {
return closeWithError(ws, gatewayErrors.ALREADY_AUTHENTICATED); return this.closeWithError(gatewayErrors.ALREADY_AUTHENTICATED);
} }
const authData = payload.d; const authData = payload.d;
if (typeof authData !== "object") { if (typeof authData !== "object") {
return closeWithBadPayload(ws, "d: expected object"); return this.closeWithBadPayload("d: expected object");
} }
if (typeof authData.token !== "string") { if (typeof authData.token !== "string") {
return closeWithBadPayload(ws, "d: invalid field 'token'"); return this.closeWithBadPayload("d: invalid field 'token'");
}
if (typeof authData.bridgesTo !== "undefined" && typeof authData.bridgesTo !== "string" && authData.bridgesTo.length > 40) {
return closeWithBadPayload(ws, "d: invalid field 'bridgesTo'");
}
if (typeof authData.privacy !== "undefined" && typeof authData.privacy !== "string" && authData.privacy.length > 200) {
return closeWithBadPayload(ws, "d: invalid field 'privacy'");
}
if (typeof authData.terms !== "undefined" && typeof authData.terms !== "string" && authData.terms.length > 200) {
return closeWithBadPayload(ws, "d: invalid field 'terms'");
} }
const user = await decodeTokenOrNull(authData.token); const user = await decodeTokenOrNull(authData.token);
if (!user) { if (!user) {
return closeWithError(ws, gatewayErrors.BAD_AUTH); return this.closeWithError(gatewayErrors.BAD_AUTH);
} }
let sessions = sessionsByUserId.get(user.id); let sessions = sessionsByUserId.get(user.id);
if (sessions) { if (sessions) {
if ((sessions.length + 1) > MAX_GATEWAY_SESSIONS_PER_USER) { if ((sessions.length + 1) > MAX_GATEWAY_SESSIONS_PER_USER) {
return closeWithError(ws, gatewayErrors.TOO_MANY_SESSIONS); return this.closeWithError(gatewayErrors.TOO_MANY_SESSIONS);
} }
} }
@ -389,42 +368,39 @@ export default function(server: Server) {
query("SELECT id, name, owner_id, avatar, created_at FROM communities ORDER BY id ASC"), query("SELECT id, name, owner_id, avatar, created_at FROM communities ORDER BY id ASC"),
]); ]);
if (!channels || !communities) { if (!channels || !communities) {
return closeWithError(ws, gatewayErrors.GOT_NO_DATABASE_DATA); return this.closeWithError(gatewayErrors.GOT_NO_DATABASE_DATA);
} }
if (!sessions) { if (!sessions) {
sessions = []; sessions = [];
sessionsByUserId.set(user.id, sessions); sessionsByUserId.set(user.id, sessions);
} }
sessions.push(ws); sessions.push(this);
clientSubscribe(ws, "*"); this.subscribe("*");
for (let i = 0; i < channels.rows.length; i++) { for (let i = 0; i < channels.rows.length; i++) {
clientSubscribe(ws, `channel:${channels.rows[i].id}`); this.subscribe(`channel:${channels.rows[i].id}`);
} }
for (let i = 0; i < communities.rows.length; i++) { for (let i = 0; i < communities.rows.length; i++) {
clientSubscribe(ws, `community:${communities.rows[i].id}`); this.subscribe(`community:${communities.rows[i].id}`);
} }
ws.state.user = user; this.user = user;
ws.state.bridgesTo = authData.bridgesTo;
ws.state.privacy = authData.privacy;
ws.state.terms = authData.terms;
// first session, notify others that we are online // first session, notify others that we are online
if (sessions.length === 1) { if (sessions.length === 1) {
dispatch("*", { dispatch("*", {
t: GatewayPayloadType.PresenceUpdate, t: GatewayPayloadType.PresenceUpdate,
d: [getPresenceEntryForConnection(ws, GatewayPresenceStatus.Online)] d: [this.getPresenceEntry(GatewayPresenceStatus.Online)]
}); });
} }
ws.state.ready = true; this.ready = true;
sendPayload(ws, { this.send({
t: GatewayPayloadType.Ready, t: GatewayPayloadType.Ready,
d: { d: {
user: getPublicUserObject(ws.state.user), user: getPublicUserObject(this.user),
channels: channels.rows, channels: channels.rows,
communities: communities.rows, communities: communities.rows,
presence: getInitialPresenceEntries() presence: getInitialPresenceEntries()
@ -434,26 +410,25 @@ export default function(server: Server) {
} }
case GatewayPayloadType.Ping: { case GatewayPayloadType.Ping: {
if (payload.d !== 0) { if (payload.d !== 0) {
return closeWithBadPayload(ws, "d: expected numeric '0'"); return this.closeWithBadPayload("d: expected numeric '0'");
} }
if (!ws.state.ready) { if (!this.ready) {
return closeWithError(ws, gatewayErrors.NOT_AUTHENTICATED); return this.closeWithError(gatewayErrors.NOT_AUTHENTICATED);
} }
// TODO: also check session here this.alive = true;
ws.state.alive = true;
break; break;
} }
case GatewayPayloadType.RPCSignal: /* through */ case GatewayPayloadType.RPCSignal: /* through */
case GatewayPayloadType.RPCRequest: { case GatewayPayloadType.RPCRequest: {
if (!ws.state.ready || !ws.state.user) { if (!this.ready || !this.user) {
return closeWithError(ws, gatewayErrors.NOT_AUTHENTICATED); return this.closeWithError(gatewayErrors.NOT_AUTHENTICATED);
} }
// RPCSignal is like RPCRequest however it does not send RPC method output unless there is an error // RPCSignal is like RPCRequest however it does not send RPC method output unless there is an error
processMethodBatch(ws.state.user, payload.d, (payload.t === GatewayPayloadType.RPCSignal ? true : false), binaryStream).then((results) => { processMethodBatch(this.user, payload.d, (payload.t === GatewayPayloadType.RPCSignal ? true : false), binaryStream).then((results) => {
sendPayload(ws, { this.send({
t: GatewayPayloadType.RPCResponse, t: GatewayPayloadType.RPCResponse,
d: results, d: results,
s: payload.s s: payload.s
@ -462,9 +437,26 @@ export default function(server: Server) {
break; break;
} }
default: { default: {
return closeWithBadPayload(ws, "t: unknown type"); return this.closeWithBadPayload("t: unknown type");
} }
} }
}
}
export default function(server: Server) {
const wss = new WebSocketServer({ server });
const batchInterval = setInterval(() => {
gatewayClients.forEach(client => client.batchTick());
}, GATEWAY_BATCH_INTERVAL);
wss.on("close", () => {
console.error("gateway: websocket server closed");
clearInterval(batchInterval);
}); });
wss.on("connection", (ws) => {
const client = new GatewayClient(ws);
client.greet();
}); });
}; };

View file

@ -1,11 +0,0 @@
interface GatewayClientState {
user?: User;
ready: boolean,
alive: boolean,
lastAliveCheck: number,
dispatchChannels: Set<string>,
messagesSinceLastCheck: number,
bridgesTo?: string
privacy?: string,
terms?: string,
}

7
src/types/ws.d.ts vendored
View file

@ -1,7 +0,0 @@
import ws from 'ws';
declare module 'ws' {
export interface WebSocket extends ws {
state: GatewayClientState;
}
}