greatly refactor gateway

This commit is contained in:
hippoz 2023-08-08 15:36:51 +03:00
parent 689787247e
commit 52d253f2cf
Signed by: hippoz
GPG key ID: 56C4E02A85F2FBED
3 changed files with 313 additions and 339 deletions

View file

@ -1,6 +1,6 @@
import { Server } from "node:http";
import { performance } from "node:perf_hooks";
import { WebSocketServer, WebSocket } from "ws";
import WebSocket, { WebSocketServer } from "ws";
import { decodeTokenOrNull, getPublicUserObject } from "../auth";
import { query } from "../database";
import { gatewayErrors } from "../errors";
@ -15,16 +15,21 @@ const GATEWAY_PING_INTERVAL = 40000;
const MAX_CLIENT_MESSAGES_PER_BATCH = 30; // TODO: how well does this work for weak connections?
const MAX_GATEWAY_SESSIONS_PER_USER = 5;
// mapping between a dispatch id and a websocket client
const dispatchChannels = new Map<string, Set<WebSocket>>();
const dispatchChannels = new Map<string, Set<GatewayClient>>();
// mapping between a user id and the websocket sessions it has
const sessionsByUserId = new Map<number, WebSocket[]>();
const sessionsByUserId = new Map<number, GatewayClient[]>();
// mapping between a dispatch id and a temporary handler
const dispatchTemporary = new Map<string, Set<(payload: GatewayPayload) => void>>();
export function handle(channels: string[], handler: (payload: GatewayPayload) => void): (() => any) {
// all clients
const gatewayClients = new Set<GatewayClient>();
function handle(channels: string[], handler: (payload: GatewayPayload) => void): (() => any) {
channels.forEach(c => {
if (!dispatchTemporary.get(c)) {
dispatchTemporary.set(c, new Set());
@ -64,65 +69,17 @@ export function waitForEvent(channels: string[], timeout: number) {
});
}
function clientSubscribe(ws: WebSocket, dispatchChannel: string) {
ws.state.dispatchChannels.add(dispatchChannel);
if (!dispatchChannels.get(dispatchChannel)) {
dispatchChannels.set(dispatchChannel, new Set());
}
dispatchChannels.get(dispatchChannel)?.add(ws);
}
function clientUnsubscribe(ws: WebSocket, dispatchChannel: string) {
if (!ws.state) return;
ws.state.dispatchChannels.delete(dispatchChannel);
const set = dispatchChannels.get(dispatchChannel);
if (!set) return;
set.delete(ws);
if (set.size < 1) {
dispatchChannels.delete(dispatchChannel);
}
}
export function dispatchChannelSubscribe(target: string, dispatchChannel: string) {
const set = dispatchChannels.get(target);
if (!set) return;
set.forEach(c => {
clientSubscribe(c, dispatchChannel);
});
}
function clientUnsubscribeAll(ws: WebSocket) {
if (!ws.state) return;
ws.state.dispatchChannels.forEach(e => {
const set = dispatchChannels.get(e);
if (!set) return;
set.delete(ws);
if (set && set.size < 1) {
dispatchChannels.delete(e);
}
});
ws.state.dispatchChannels = new Set();
}
export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSocket | null) => GatewayPayload)) {
export function dispatch(channel: string, message: GatewayPayload | ((ws: GatewayClient | null) => GatewayPayload)) {
const members = dispatchChannels.get(channel);
if (!members) return;
members.forEach(e => {
if (e.state.ready) {
if (e.ready) {
let data = message;
if (typeof message === "function") {
data = message(e);
}
e.send(JSON.stringify(data));
e.send(data);
}
});
@ -139,14 +96,34 @@ export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSoc
}
}
function closeWithError(ws: WebSocket, { code, message }: { code: number, message: string }) {
return ws.close(code, message);
export function dispatchChannelSubscribe(target: string, dispatchChannel: string) {
const set = dispatchChannels.get(target);
set?.forEach(c => c.subscribe(dispatchChannel));
}
function closeWithBadPayload(ws: WebSocket, hint: string) {
return ws.close(gatewayErrors.BAD_PAYLOAD.code, `${gatewayErrors.BAD_PAYLOAD.message}: ${hint}`);
}
function getInitialPresenceEntries(): GatewayPresenceEntry[] {
// The initial presence entries are sent right when the user connects.
// In the future, each user will have their own list of channels that they can join and leave.
// In that case, we will send the presence entries to a certain user only for the channels they're in.
const entries: GatewayPresenceEntry[] = [];
sessionsByUserId.forEach((clients: GatewayClient[], userId: number) => {
if (clients.length < 1)
return;
const firstClient = clients[0];
if (firstClient.ready && firstClient.user) {
const entry = firstClient.getPresenceEntry(GatewayPresenceStatus.Online);
if (entry) {
entries.push(entry);
}
}
});
return entries;
}
function parseJsonOrNull(payload: string): any {
try {
return JSON.parse(payload);
@ -155,8 +132,6 @@ function parseJsonOrNull(payload: string): any {
}
}
// The function below ensures `payload` is of the GatewayPayload
// interface payload. If it does not match, null is returned.
function ensureFormattedGatewayPayload(payload: any): GatewayPayload | null {
if (!payload) {
return null;
@ -183,288 +158,305 @@ function ensureFormattedGatewayPayload(payload: any): GatewayPayload | null {
return asPayload;
}
function sendPayload(ws: WebSocket, payload: GatewayPayload) {
ws.send(JSON.stringify(payload));
}
function getPresenceEntryForConnection(ws: WebSocket, status: GatewayPresenceStatus): GatewayPresenceEntry | null {
if (!ws.state || !ws.state.user) {
return null;
class GatewayClient {
ws: WebSocket;
user?: User;
ready: boolean;
alive: boolean;
lastAliveCheck: number;
clientDispatchChannels: Set<string>;
messagesSinceLastCheck: number;
bridgesTo?: string;
privacy?: string;
terms?: string;
constructor(ws: WebSocket) {
this.ws = ws;
this.user = undefined;
this.ready = false;
this.alive = false;
this.lastAliveCheck = performance.now();
this.clientDispatchChannels = new Set();
this.messagesSinceLastCheck = 0;
this.bridgesTo = undefined;
this.privacy = undefined;
this.terms = undefined;
gatewayClients.add(this);
this.ws.on("close", this.handleClose.bind(this));
this.ws.on("message", this.handleMessage.bind(this));
}
const entry: GatewayPresenceEntry = {
user: {
id: ws.state.user.id,
username: ws.state.user.username,
avatar: ws.state.user.avatar
},
status
};
if (typeof ws.state.bridgesTo === "string") {
entry.bridgesTo = ws.state.bridgesTo;
}
if (typeof ws.state.privacy === "string") {
entry.privacy = ws.state.privacy;
}
if (typeof ws.state.terms === "string") {
entry.terms = ws.state.terms;
greet() {
this.send({
t: GatewayPayloadType.Hello,
d: {
pingInterval: GATEWAY_PING_INTERVAL
}
});
}
return entry;
}
subscribe(channel: string) {
this.clientDispatchChannels.add(channel);
if (!dispatchChannels.get(channel)) {
dispatchChannels.set(channel, new Set());
}
dispatchChannels.get(channel)?.add(this);
}
// The initial presence entries are sent right when the user connects.
// In the future, each user will have their own list of channels that they can join and leave.
// In that case, we will send the presence entries to a certain user only for the channels they're in.
function getInitialPresenceEntries(): GatewayPresenceEntry[] {
const entries: GatewayPresenceEntry[] = [];
unsubscribeAll() {
this.clientDispatchChannels.forEach((channel) => {
const set = dispatchChannels.get(channel);
if (!set) return;
set.delete(this);
if (set && set.size < 1) {
dispatchChannels.delete(channel);
}
});
this.clientDispatchChannels.clear();
}
sessionsByUserId.forEach((wsList: WebSocket[], userId: number) => {
if (wsList.length < 1)
return;
send(payload: object) {
this.ws.send(JSON.stringify(payload));
}
const firstWs = wsList[0];
if (firstWs.state.ready && firstWs.state.user) {
const entry = getPresenceEntryForConnection(firstWs, GatewayPresenceStatus.Online);
if (entry) {
entries.push(entry);
closeWithError({ code, message }: { code: number, message: string }) {
this.ws.close(code, message);
}
closeWithBadPayload(hint: string) {
this.ws.close(gatewayErrors.BAD_PAYLOAD.code, `${gatewayErrors.BAD_PAYLOAD.message}: ${hint}`);
}
getPresenceEntry(status: GatewayPresenceStatus): GatewayPresenceEntry | null {
if (!this.user || !this.ready) {
return null;
}
return {
user: {
id: this.user.id,
username: this.user.username,
avatar: this.user.avatar
},
status,
bridgesTo: this.bridgesTo,
privacy: this.privacy,
terms: this.terms,
};
}
handleClose() {
gatewayClients.delete(this);
this.unsubscribeAll();
this.ready = false;
if (this.user) {
const sessions = sessionsByUserId.get(this.user.id);
if (sessions) {
const index = sessions.indexOf(this);
sessions.splice(index, 1);
if (!sessions.length) {
sessionsByUserId.delete(this.user.id);
// user no longer has any sessions, update presence
dispatch("*", {
t: GatewayPayloadType.PresenceUpdate,
d: [
this.getPresenceEntry(GatewayPresenceStatus.Offline)
]
});
}
}
}
});
}
return entries;
batchTick() {
const now = performance.now();
if ((now - this.lastAliveCheck) >= GATEWAY_PING_INTERVAL) {
if (!this.ready) {
return this.closeWithError(gatewayErrors.AUTHENTICATION_TIMEOUT);
}
if (!this.alive) {
return this.closeWithError(gatewayErrors.NO_PING);
}
this.messagesSinceLastCheck = 0;
}
}
async handleMessage(rawData: Buffer, isBinary: boolean) {
this.messagesSinceLastCheck++;
if (rawData.byteLength >= maxGatewayPayloadByteLength) {
return this.closeWithError(gatewayErrors.PAYLOAD_TOO_LARGE);
}
if (this.messagesSinceLastCheck > MAX_CLIENT_MESSAGES_PER_BATCH) {
return this.closeWithError(gatewayErrors.FLOODING);
}
let stringData: string;
let binaryStream: Buffer | null = null;
if (isBinary) {
// Binary frames are used in order combine our text data (JSON) with binary data.
// This is especially useful for calling RPC methods that, for example, upload files.
// The format is: [json payload]\n[begin binary stream]
let jsonSlice;
let jsonOffset = -1;
for (let i = 0; i < maxGatewayJsonStringByteLength; i++) {
if (rawData.readUInt8(i) === 0x0A) {
// hit newline
jsonSlice = rawData.subarray(0, i);
jsonOffset = i + 1;
break;
}
}
if (!jsonSlice) {
return this.closeWithBadPayload("Did not find newline to delimit JSON from binary stream. JSON payload may be too large, or newline may be missing.");
}
binaryStream = rawData.subarray(jsonOffset, rawData.byteLength);
stringData = jsonSlice.toString();
} else {
stringData = rawData.toString();
}
if (stringData.length > maxGatewayJsonStringLength) {
return this.closeWithError(gatewayErrors.PAYLOAD_TOO_LARGE);
}
const payload = ensureFormattedGatewayPayload(parseJsonOrNull(stringData));
if (!payload) {
return this.closeWithBadPayload("Invalid JSON or message does not match schema");
}
switch (payload.t) {
case GatewayPayloadType.Authenticate: {
if (this.ready) {
return this.closeWithError(gatewayErrors.ALREADY_AUTHENTICATED);
}
const authData = payload.d;
if (typeof authData !== "object") {
return this.closeWithBadPayload("d: expected object");
}
if (typeof authData.token !== "string") {
return this.closeWithBadPayload("d: invalid field 'token'");
}
const user = await decodeTokenOrNull(authData.token);
if (!user) {
return this.closeWithError(gatewayErrors.BAD_AUTH);
}
let sessions = sessionsByUserId.get(user.id);
if (sessions) {
if ((sessions.length + 1) > MAX_GATEWAY_SESSIONS_PER_USER) {
return this.closeWithError(gatewayErrors.TOO_MANY_SESSIONS);
}
}
// TODO: each user should have their own list of channels that they join
const [channels, communities] = await Promise.all([
query("SELECT id, name, owner_id, community_id FROM channels ORDER BY id ASC"),
query("SELECT id, name, owner_id, avatar, created_at FROM communities ORDER BY id ASC"),
]);
if (!channels || !communities) {
return this.closeWithError(gatewayErrors.GOT_NO_DATABASE_DATA);
}
if (!sessions) {
sessions = [];
sessionsByUserId.set(user.id, sessions);
}
sessions.push(this);
this.subscribe("*");
for (let i = 0; i < channels.rows.length; i++) {
this.subscribe(`channel:${channels.rows[i].id}`);
}
for (let i = 0; i < communities.rows.length; i++) {
this.subscribe(`community:${communities.rows[i].id}`);
}
this.user = user;
// first session, notify others that we are online
if (sessions.length === 1) {
dispatch("*", {
t: GatewayPayloadType.PresenceUpdate,
d: [this.getPresenceEntry(GatewayPresenceStatus.Online)]
});
}
this.ready = true;
this.send({
t: GatewayPayloadType.Ready,
d: {
user: getPublicUserObject(this.user),
channels: channels.rows,
communities: communities.rows,
presence: getInitialPresenceEntries()
}
});
break;
}
case GatewayPayloadType.Ping: {
if (payload.d !== 0) {
return this.closeWithBadPayload("d: expected numeric '0'");
}
if (!this.ready) {
return this.closeWithError(gatewayErrors.NOT_AUTHENTICATED);
}
this.alive = true;
break;
}
case GatewayPayloadType.RPCSignal: /* through */
case GatewayPayloadType.RPCRequest: {
if (!this.ready || !this.user) {
return this.closeWithError(gatewayErrors.NOT_AUTHENTICATED);
}
// RPCSignal is like RPCRequest however it does not send RPC method output unless there is an error
processMethodBatch(this.user, payload.d, (payload.t === GatewayPayloadType.RPCSignal ? true : false), binaryStream).then((results) => {
this.send({
t: GatewayPayloadType.RPCResponse,
d: results,
s: payload.s
});
});
break;
}
default: {
return this.closeWithBadPayload("t: unknown type");
}
}
}
}
export default function(server: Server) {
const wss = new WebSocketServer({ server });
const batchInterval = setInterval(() => {
wss.clients.forEach((e) => {
const now = performance.now();
if (e.state && (now - e.state.lastAliveCheck) >= GATEWAY_PING_INTERVAL) {
if (!e.state.ready) {
return closeWithError(e, gatewayErrors.AUTHENTICATION_TIMEOUT);
}
if (!e.state.alive) {
return closeWithError(e, gatewayErrors.NO_PING);
}
e.state.messagesSinceLastCheck = 0;
}
});
gatewayClients.forEach(client => client.batchTick());
}, GATEWAY_BATCH_INTERVAL);
wss.on("close", () => {
console.warn("gateway: websocket server closed");
console.warn("gateway: clearing batch interval due to websocket server close");
console.error("gateway: websocket server closed");
clearInterval(batchInterval);
});
wss.on("connection", (ws) => {
ws.state = {
user: undefined,
alive: false,
ready: false,
lastAliveCheck: performance.now(),
dispatchChannels: new Set(),
messagesSinceLastCheck: 0,
bridgesTo: undefined
};
sendPayload(ws, {
t: GatewayPayloadType.Hello,
d: {
pingInterval: GATEWAY_PING_INTERVAL
}
});
ws.on("close", () => {
clientUnsubscribeAll(ws);
ws.state.ready = false;
if (ws.state.user && ws.state.user.id) {
const sessions = sessionsByUserId.get(ws.state.user.id);
if (sessions) {
const index = sessions.indexOf(ws);
sessions.splice(index, 1);
if (sessions.length < 1) {
sessionsByUserId.delete(ws.state.user.id);
// user no longer has any sessions, update presence
dispatch("*", {
t: GatewayPayloadType.PresenceUpdate,
d: [getPresenceEntryForConnection(ws, GatewayPresenceStatus.Offline)]
});
}
}
}
});
ws.on("message", async (rawData: Buffer, isBinary) => {
ws.state.messagesSinceLastCheck++;
if (ws.state.messagesSinceLastCheck > MAX_CLIENT_MESSAGES_PER_BATCH) {
return closeWithError(ws, gatewayErrors.FLOODING);
}
if (rawData.byteLength >= maxGatewayPayloadByteLength) {
return closeWithError(ws, gatewayErrors.PAYLOAD_TOO_LARGE);
}
let stringData: string;
let binaryStream: Buffer | null = null;
if (isBinary) {
// Binary frames are used in order combine our text data (JSON) with binary data.
// This is especially useful for calling RPC methods that, for example, upload files.
// The format is: [json payload]\n[begin binary stream]
let jsonSlice;
let jsonOffset = -1;
for (let i = 0; i < maxGatewayJsonStringByteLength; i++) {
if (rawData.readUInt8(i) === 0x0A) {
// hit newline
jsonSlice = rawData.subarray(0, i);
jsonOffset = i + 1;
break;
}
}
if (!jsonSlice) {
return closeWithBadPayload(ws, "Did not find newline to delimit JSON from binary stream. JSON payload may be too large, or newline may be missing.");
}
binaryStream = rawData.subarray(jsonOffset, rawData.byteLength);
stringData = jsonSlice.toString();
} else {
stringData = rawData.toString();
}
if (stringData.length > maxGatewayJsonStringLength) {
return closeWithError(ws, gatewayErrors.PAYLOAD_TOO_LARGE);
}
const payload = ensureFormattedGatewayPayload(parseJsonOrNull(stringData));
if (!payload) {
return closeWithBadPayload(ws, "Invalid JSON or message does not match schema");
}
switch (payload.t) {
case GatewayPayloadType.Authenticate: {
if (ws.state.ready) {
return closeWithError(ws, gatewayErrors.ALREADY_AUTHENTICATED);
}
const authData = payload.d;
if (typeof authData !== "object") {
return closeWithBadPayload(ws, "d: expected object");
}
if (typeof authData.token !== "string") {
return closeWithBadPayload(ws, "d: invalid field 'token'");
}
if (typeof authData.bridgesTo !== "undefined" && typeof authData.bridgesTo !== "string" && authData.bridgesTo.length > 40) {
return closeWithBadPayload(ws, "d: invalid field 'bridgesTo'");
}
if (typeof authData.privacy !== "undefined" && typeof authData.privacy !== "string" && authData.privacy.length > 200) {
return closeWithBadPayload(ws, "d: invalid field 'privacy'");
}
if (typeof authData.terms !== "undefined" && typeof authData.terms !== "string" && authData.terms.length > 200) {
return closeWithBadPayload(ws, "d: invalid field 'terms'");
}
const user = await decodeTokenOrNull(authData.token);
if (!user) {
return closeWithError(ws, gatewayErrors.BAD_AUTH);
}
let sessions = sessionsByUserId.get(user.id);
if (sessions) {
if ((sessions.length + 1) > MAX_GATEWAY_SESSIONS_PER_USER) {
return closeWithError(ws, gatewayErrors.TOO_MANY_SESSIONS);
}
}
// TODO: each user should have their own list of channels that they join
const [channels, communities] = await Promise.all([
query("SELECT id, name, owner_id, community_id FROM channels ORDER BY id ASC"),
query("SELECT id, name, owner_id, avatar, created_at FROM communities ORDER BY id ASC"),
]);
if (!channels || !communities) {
return closeWithError(ws, gatewayErrors.GOT_NO_DATABASE_DATA);
}
if (!sessions) {
sessions = [];
sessionsByUserId.set(user.id, sessions);
}
sessions.push(ws);
clientSubscribe(ws, "*");
for (let i = 0; i < channels.rows.length; i++) {
clientSubscribe(ws, `channel:${channels.rows[i].id}`);
}
for (let i = 0; i < communities.rows.length; i++) {
clientSubscribe(ws, `community:${communities.rows[i].id}`);
}
ws.state.user = user;
ws.state.bridgesTo = authData.bridgesTo;
ws.state.privacy = authData.privacy;
ws.state.terms = authData.terms;
// first session, notify others that we are online
if (sessions.length === 1) {
dispatch("*", {
t: GatewayPayloadType.PresenceUpdate,
d: [getPresenceEntryForConnection(ws, GatewayPresenceStatus.Online)]
});
}
ws.state.ready = true;
sendPayload(ws, {
t: GatewayPayloadType.Ready,
d: {
user: getPublicUserObject(ws.state.user),
channels: channels.rows,
communities: communities.rows,
presence: getInitialPresenceEntries()
}
});
break;
}
case GatewayPayloadType.Ping: {
if (payload.d !== 0) {
return closeWithBadPayload(ws, "d: expected numeric '0'");
}
if (!ws.state.ready) {
return closeWithError(ws, gatewayErrors.NOT_AUTHENTICATED);
}
// TODO: also check session here
ws.state.alive = true;
break;
}
case GatewayPayloadType.RPCSignal: /* through */
case GatewayPayloadType.RPCRequest: {
if (!ws.state.ready || !ws.state.user) {
return closeWithError(ws, gatewayErrors.NOT_AUTHENTICATED);
}
// RPCSignal is like RPCRequest however it does not send RPC method output unless there is an error
processMethodBatch(ws.state.user, payload.d, (payload.t === GatewayPayloadType.RPCSignal ? true : false), binaryStream).then((results) => {
sendPayload(ws, {
t: GatewayPayloadType.RPCResponse,
d: results,
s: payload.s
});
});
break;
}
default: {
return closeWithBadPayload(ws, "t: unknown type");
}
}
});
const client = new GatewayClient(ws);
client.greet();
});
};

View file

@ -1,11 +0,0 @@
interface GatewayClientState {
user?: User;
ready: boolean,
alive: boolean,
lastAliveCheck: number,
dispatchChannels: Set<string>,
messagesSinceLastCheck: number,
bridgesTo?: string
privacy?: string,
terms?: string,
}

7
src/types/ws.d.ts vendored
View file

@ -1,7 +0,0 @@
import ws from 'ws';
declare module 'ws' {
export interface WebSocket extends ws {
state: GatewayClientState;
}
}