diff --git a/src/database/index.ts b/src/database/index.ts index 8d4ffe8..8c44b24 100644 --- a/src/database/index.ts +++ b/src/database/index.ts @@ -1,4 +1,4 @@ -import { Pool, PoolClient, QueryResult } from "pg"; +import { Client, Pool, PoolClient, QueryResult } from "pg"; const pool = new Pool(); @@ -40,3 +40,32 @@ export async function withClient(callback: (client: PoolClient) => Promise) client.release(); return result; } + +const tries = 60; + +export async function waitForDatabase(): Promise { + let success = false; + + for (let i = 0; i < tries; i++) { + const waitingClient = new Client({ connectionTimeoutMillis: 1000 }); + try { + await waitingClient.connect(); + success = true; + } catch(o_O) { + console.log("database connection failed, trying again..."); + } finally { + await waitingClient.end(); + } + + if (success) { + break; + } + } + + if (success) { + return true; + } else { + console.error(`failed to connect to database after ${tries} tries`); + return false; + } +} diff --git a/src/database/init.ts b/src/database/init.ts index a1dbb49..fc825ed 100644 --- a/src/database/init.ts +++ b/src/database/init.ts @@ -1,77 +1,77 @@ -import { query } from "."; +import { Client } from "pg"; +import { query, waitForDatabase } from "."; + +const migrationQuery = ` +CREATE TABLE IF NOT EXISTS users( + id SERIAL PRIMARY KEY, + username VARCHAR(32) UNIQUE NOT NULL, + password TEXT, + is_superuser BOOLEAN +); +CREATE TABLE IF NOT EXISTS channels( + id SERIAL PRIMARY KEY, + name VARCHAR(32) NOT NULL, + owner_id SERIAL REFERENCES users ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS messages( + id SERIAL PRIMARY KEY, + content VARCHAR(4000) NOT NULL, + channel_id SERIAL REFERENCES channels ON DELETE CASCADE, + author_id SERIAL REFERENCES users ON DELETE CASCADE, + created_at BIGINT +); +ALTER TABLE messages ADD COLUMN IF NOT EXISTS nick_username VARCHAR(64) DEFAULT NULL; +ALTER TABLE users ADD COLUMN IF NOT EXISTS avatar VARCHAR(48) DEFAULT NULL; +CREATE TABLE IF NOT EXISTS communities( + id SERIAL PRIMARY KEY, + name VARCHAR(64) NOT NULL, + owner_id SERIAL REFERENCES users ON DELETE CASCADE, + avatar VARCHAR(48) DEFAULT NULL, + created_at BIGINT +); +ALTER TABLE channels ADD COLUMN IF NOT EXISTS community_id INT NULL REFERENCES communities(id) ON DELETE CASCADE; +CREATE TABLE IF NOT EXISTS message_attachments( + id SERIAL PRIMARY KEY, + type VARCHAR(64) NOT NULL, + owner_id SERIAL REFERENCES users ON DELETE CASCADE, + message_id SERIAL REFERENCES messages ON DELETE CASCADE, + created_at BIGINT, + + file VARCHAR(256) DEFAULT NULL, + file_mime VARCHAR(256) DEFAULT NULL, + file_size_bytes BIGINT DEFAULT NULL +); +ALTER TABLE messages ADD COLUMN IF NOT EXISTS pending_attachments INT DEFAULT NULL; +ALTER TABLE message_attachments ADD COLUMN IF NOT EXISTS width INT DEFAULT NULL; +ALTER TABLE message_attachments ADD COLUMN IF NOT EXISTS height INT DEFAULT NULL; +ALTER TABLE message_attachments ADD COLUMN IF NOT EXISTS file_name VARCHAR(256) DEFAULT NULL; +`; export default async function databaseInit() { - const migrations = [ - ` - CREATE TABLE IF NOT EXISTS users( - id SERIAL PRIMARY KEY, - username VARCHAR(32) UNIQUE NOT NULL, - password TEXT, - is_superuser BOOLEAN - ); - `, - ` - CREATE TABLE IF NOT EXISTS channels( - id SERIAL PRIMARY KEY, - name VARCHAR(32) NOT NULL, - owner_id SERIAL REFERENCES users ON DELETE CASCADE - ); - `, - ` - CREATE TABLE IF NOT EXISTS messages( - id SERIAL PRIMARY KEY, - content VARCHAR(4000) NOT NULL, - channel_id SERIAL REFERENCES channels ON DELETE CASCADE, - author_id SERIAL REFERENCES users ON DELETE CASCADE, - created_at BIGINT - ); - `, - ` - ALTER TABLE messages ADD COLUMN IF NOT EXISTS nick_username VARCHAR(64) DEFAULT NULL; - `, - ` - ALTER TABLE users ADD COLUMN IF NOT EXISTS avatar VARCHAR(48) DEFAULT NULL; - `, - ` - CREATE TABLE IF NOT EXISTS communities( - id SERIAL PRIMARY KEY, - name VARCHAR(64) NOT NULL, - owner_id SERIAL REFERENCES users ON DELETE CASCADE, - avatar VARCHAR(48) DEFAULT NULL, - created_at BIGINT - ); - `, - ` - ALTER TABLE channels ADD COLUMN IF NOT EXISTS community_id INT NULL REFERENCES communities(id) ON DELETE CASCADE; - `, - ` - CREATE TABLE IF NOT EXISTS message_attachments( - id SERIAL PRIMARY KEY, - type VARCHAR(64) NOT NULL, - owner_id SERIAL REFERENCES users ON DELETE CASCADE, - message_id SERIAL REFERENCES messages ON DELETE CASCADE, - created_at BIGINT, - - file VARCHAR(256) DEFAULT NULL, - file_mime VARCHAR(256) DEFAULT NULL, - file_size_bytes BIGINT DEFAULT NULL - ); - `, - ` - ALTER TABLE messages ADD COLUMN IF NOT EXISTS pending_attachments INT DEFAULT NULL; - `, - ` - ALTER TABLE message_attachments ADD COLUMN IF NOT EXISTS width INT DEFAULT NULL; - `, - ` - ALTER TABLE message_attachments ADD COLUMN IF NOT EXISTS height INT DEFAULT NULL; - `, - ` - ALTER TABLE message_attachments ADD COLUMN IF NOT EXISTS file_name VARCHAR(256) DEFAULT NULL; - `, - ]; - - for (let i = 0; i < migrations.length; i++) { - await query(migrations[i], [], false); + const success = await waitForDatabase(); + if (!success) { + console.error("databaseInit: database is not available, exiting..."); + process.exit(1); } + + const start = performance.now(); + + const client = new Client(); + await client.connect(); + + try { + await client.query("BEGIN"); + await client.query(migrationQuery); + await client.query("COMMIT"); + } catch (O_o) { + console.error("failed to apply migrations, rolling back", O_o); + await client.query("ROLLBACK"); + process.exit(1); + } finally { + await client.end(); + } + + const delta = performance.now() - start; + + console.log(`refreshed migrations in ${delta}ms`); } \ No newline at end of file diff --git a/src/index.ts b/src/index.ts index 7924f15..3ffe9f5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -18,14 +18,7 @@ function serve() { } async function main() { - if (process.argv[2] === "db-init") { - console.log("db-init: initializing database..."); - await databaseInit(); - console.log("db-init: databaseInit() finished"); - console.log("database initialized, exiting..."); - process.exit(0); - return; - } + await databaseInit(); serve(); } diff --git a/src/routes/api/v1/rpc.ts b/src/routes/api/v1/rpc.ts index 5690429..0238b95 100644 --- a/src/routes/api/v1/rpc.ts +++ b/src/routes/api/v1/rpc.ts @@ -2,6 +2,7 @@ import { errors } from "../../../errors"; import express from "express"; import { authenticateRoute } from "../../../auth"; import { processMethodBatch } from "../../../rpc/rpc"; +import { methodMap } from "../../../rpc"; const router = express.Router(); @@ -31,4 +32,12 @@ router.get( } ); +router.get( + "/methods", + authenticateRoute(false), + async (req, res) => { + res.json(methodMap); + } +); + export default router; diff --git a/src/rpc/apis/attachments.ts b/src/rpc/apis/attachments.ts index 736dcfb..c49f904 100644 --- a/src/rpc/apis/attachments.ts +++ b/src/rpc/apis/attachments.ts @@ -8,14 +8,13 @@ import path from "node:path"; import { promises as fsPromises } from "node:fs"; import { getMessageById } from "../../database/templates"; import { uploadsMode } from "../../serverconfig"; -import { dispatch } from "../../gateway"; import { GatewayPayloadType } from "../../gateway/gatewaypayloadtype"; const fileType = eval("import('file-type')"); method( "createMessageAttachment", - [uint(), string(2, 128), bufferSlice()], + [uint("messageId", "ID of the target message"), string("filenameUnsafe", "Name of the file", 2, 128), bufferSlice("inputBuffer", "File data")], async (user: User, messageId: number, filenameUnsafe: string, inputBuffer: Buffer, ctx: RPCContext) => { if (inputBuffer.byteLength >= 16777220) { return { ...errors.BAD_REQUEST, detail: "Uploaded file exceeds 16MiB limit." }; diff --git a/src/rpc/apis/channels.ts b/src/rpc/apis/channels.ts index 90647b5..1530909 100644 --- a/src/rpc/apis/channels.ts +++ b/src/rpc/apis/channels.ts @@ -2,14 +2,14 @@ import { channelNameRegex, method, int, string, uint, withOptional, withRegexp, import { query } from "../../database"; import { getMessagesByChannelFirstPage, getMessagesByChannelPage } from "../../database/templates"; import { errors } from "../../errors"; -import { dispatch, dispatchChannelSubscribe } from "../../gateway"; +import { dispatchChannelSubscribe } from "../../gateway"; import { GatewayPayloadType } from "../../gateway/gatewaypayloadtype"; import sendMessage from "../../impl"; import serverConfig from "../../serverconfig"; method( "createChannel", - [withRegexp(channelNameRegex, string(1, 32)), withOptional(uint())], + [withRegexp(channelNameRegex, string("name", "Name of the channel to create", 1, 32)), withOptional(uint("communityId", "ID of the community to attach this channel to"))], async (user: User, name: string, communityId: number | null, ctx: RPCContext) => { if (serverConfig.superuserRequirement.createChannel && !user.is_superuser) { return errors.FORBIDDEN_DUE_TO_MISSING_PERMISSIONS; @@ -35,7 +35,7 @@ method( method( "updateChannelName", - [uint(), withRegexp(channelNameRegex, string(1, 32))], + [uint("id", "ID of the channel to update"), withRegexp(channelNameRegex, string("name", "New channel name", 1, 32))], async (user: User, id: number, name: string, ctx: RPCContext) => { const permissionCheckResult = await query("SELECT owner_id FROM channels WHERE id = $1", [id]); if (!permissionCheckResult || permissionCheckResult.rowCount < 1) { @@ -61,7 +61,7 @@ method( method( "deleteChannel", - [uint()], + [uint("id", "ID of the channel to delete")], async (user: User, id: number, ctx: RPCContext) => { const permissionCheckResult = await query("SELECT owner_id FROM channels WHERE id = $1", [id]); if (!permissionCheckResult || permissionCheckResult.rowCount < 1) { @@ -87,7 +87,7 @@ method( method( "getChannel", - [uint()], + [uint("id", "ID of the channel")], async (_user: User, id: number) => { const result = await query("SELECT id, name, owner_id, community_id FROM channels WHERE id = $1", [id]); if (!result || result.rowCount < 1) { @@ -110,7 +110,13 @@ method( method( "createChannelMessage", - [uint(), string(1, 4000), withOptional(uint()), withOptional(string(1, 64)), withOptional(uint())], + [ + uint("id", "ID of the channel to send the message into"), + string("content", "Text content of the message", 1, 4000), + withOptional(uint("optimistic_id", "User-specific ID of the message to easily identify it over the gateway")), + withOptional(string("nick_username", "Username to display", 1, 64)), + withOptional(uint("pending_attachments", "Number of attachments expected to be added")) + ], async (user: User, id: number, content: string, optimistic_id: number | null, nick_username: string | null, pending_attachments: number | null, ctx: RPCContext) => { return await sendMessage(user, id, optimistic_id, content, nick_username, pending_attachments ?? 0, ctx); } @@ -118,7 +124,7 @@ method( method( "getChannelMessages", - [uint(), withOptional(int(5, 100)), withOptional(uint())], + [uint("channelId", "ID of the channel"), withOptional(int("count", "Number of messages", 5, 100)), withOptional(uint("before", "If specified, send only messages before this message ID"))], async (_user: User, channelId: number, count: number | null, before: number | null) => { let limit = count ?? 25; @@ -138,7 +144,7 @@ method( method( "putChannelTyping", - [uint()], + [uint("channelId", "ID of the channel")], async (user: User, channelId: number, ctx: RPCContext) => { ctx.gatewayDispatch(`channel:${channelId}`, { t: GatewayPayloadType.TypingStart, diff --git a/src/rpc/apis/communities.ts b/src/rpc/apis/communities.ts index 70fea4c..55e4eff 100644 --- a/src/rpc/apis/communities.ts +++ b/src/rpc/apis/communities.ts @@ -1,13 +1,13 @@ import { channelNameRegex, method, int, string, uint, withRegexp, RPCContext } from "../rpc"; import { query } from "../../database"; import { errors } from "../../errors"; -import { dispatch, dispatchChannelSubscribe } from "../../gateway"; +import { dispatchChannelSubscribe } from "../../gateway"; import { GatewayPayloadType } from "../../gateway/gatewaypayloadtype"; import serverConfig from "../../serverconfig"; method( "createCommunity", - [withRegexp(channelNameRegex, string(1, 64))], + [withRegexp(channelNameRegex, string("name", "Name of the community to create", 1, 64))], async (user: User, name: string, ctx: RPCContext) => { if (serverConfig.superuserRequirement.createChannel && !user.is_superuser) { return errors.FORBIDDEN_DUE_TO_MISSING_PERMISSIONS; @@ -30,7 +30,7 @@ method( method( "updateCommunityName", - [uint(), withRegexp(channelNameRegex, string(1, 32))], + [uint("id", "ID of the community to update"), withRegexp(channelNameRegex, string("name", "New community name", 1, 32))], async (user: User, id: number, name: string, ctx: RPCContext) => { const permissionCheckResult = await query("SELECT owner_id FROM communities WHERE id = $1", [id]); if (!permissionCheckResult || permissionCheckResult.rowCount < 1) { @@ -56,7 +56,7 @@ method( method( "deleteCommunity", - [uint()], + [uint("id", "ID of the community to delete")], async (user: User, id: number, ctx: RPCContext) => { const permissionCheckResult = await query("SELECT owner_id FROM communities WHERE id = $1", [id]); if (!permissionCheckResult || permissionCheckResult.rowCount < 1) { @@ -82,7 +82,7 @@ method( method( "getCommunity", - [uint()], + [uint("id", "ID of the community")], async (_user: User, id: number) => { const result = await query("SELECT id, name, owner_id, avatar, created_at FROM communities WHERE id = $1", [id]); if (!result || result.rowCount < 1) { @@ -105,7 +105,7 @@ method( method( "getCommunityChannels", - [uint()], + [uint("id", "ID of the community")], async (_user: User, id: number) => { const result = await query("SELECT id, name, owner_id, community_id FROM channels WHERE community_id = $1", [id]); diff --git a/src/rpc/apis/messages.ts b/src/rpc/apis/messages.ts index 09362df..5dc00c8 100644 --- a/src/rpc/apis/messages.ts +++ b/src/rpc/apis/messages.ts @@ -2,7 +2,6 @@ import { RPCContext, method, string, uint } from "./../rpc"; import { query } from "../../database"; import { getMessageById } from "../../database/templates"; import { errors } from "../../errors"; -import { dispatch } from "../../gateway"; import { GatewayPayloadType } from "../../gateway/gatewaypayloadtype"; import { unlink } from "node:fs/promises"; import path from "node:path"; @@ -10,7 +9,7 @@ import { UploadTarget, getSafeUploadPath } from "../../uploading"; method( "deleteMessage", - [uint()], + [uint("id", "ID of the message to delete")], async (user: User, id: number, ctx: RPCContext) => { const messageCheckResult = await query(getMessageById, [id]); if (!messageCheckResult || messageCheckResult.rowCount < 1) { @@ -60,7 +59,7 @@ method( method( "updateMessageContent", - [uint(), string(1, 4000)], + [uint("id", "ID of the message to update"), string("content", "New message text content", 1, 4000)], async (user: User, id: number, content: string, ctx: RPCContext) => { const permissionCheckResult = await query(getMessageById, [id]); if (!permissionCheckResult || permissionCheckResult.rowCount < 1) { @@ -91,7 +90,7 @@ method( method( "getMessage", - [uint()], + [uint("id", "ID of the message")], async (user: User, id: number) => { const result = await query(getMessageById, [id]); if (!result || result.rowCount < 1) { diff --git a/src/rpc/apis/users.ts b/src/rpc/apis/users.ts index 462a32b..5b56923 100644 --- a/src/rpc/apis/users.ts +++ b/src/rpc/apis/users.ts @@ -1,6 +1,6 @@ import { errors } from "../../errors"; import { query } from "../../database"; -import { compare, hash, hashSync } from "bcrypt"; +import { compare, hash } from "bcrypt"; import { getPublicUserObject, loginAttempt } from "../../auth"; import { RPCContext, bufferSlice, method, methodButWarningDoesNotAuthenticate, string, usernameRegex, withRegexp } from "./../rpc"; import sharp from "sharp"; @@ -8,7 +8,6 @@ import path from "path"; import { randomBytes } from "crypto"; import { unlink } from "fs/promises"; import { GatewayPayloadType } from "../../gateway/gatewaypayloadtype"; -import { dispatch } from "../../gateway"; import { supportedImageMime } from "../../uploading"; import { avatarUploadDirectory, disableAccountCreation, superuserKey } from "../../serverconfig"; @@ -17,7 +16,7 @@ const fileType = eval("import('file-type')"); methodButWarningDoesNotAuthenticate( "createUser", - [withRegexp(usernameRegex, string(3, 32)), string(8, 1000)], + [withRegexp(usernameRegex, string("username", "Username of the user to create", 3, 32)), string("password", "Password of the user to create", 8, 1000)], async (username: string, password: string) => { if (disableAccountCreation) { return errors.FEATURE_DISABLED; @@ -43,7 +42,7 @@ methodButWarningDoesNotAuthenticate( methodButWarningDoesNotAuthenticate( "loginUser", - [withRegexp(usernameRegex, string(3, 32)), string(8, 1000)], + [withRegexp(usernameRegex, string("username", "Username of the account to log into", 3, 32)), string("password", "Password of the account to log into", 8, 1000)], async (username: string, password: string) => { const token = await loginAttempt(username, password); if (!token) { @@ -64,7 +63,7 @@ method( method( "promoteUserSelf", - [string(1, 1000)], + [string("key", "Superuser key", 1, 1000)], async (user: User, key: string) => { if (!superuserKey) { return errors.FEATURE_DISABLED; @@ -91,7 +90,7 @@ const profilePictureSizes = [ method( "putUserAvatar", - [bufferSlice()], + [bufferSlice("buffer", "Avatar file bytes")], async (user: User, buffer: Buffer, ctx: RPCContext) => { if (buffer.byteLength >= 3145728) { // buffer exceeds 3MiB diff --git a/src/rpc/index.ts b/src/rpc/index.ts index 7fce692..ed72c1f 100644 --- a/src/rpc/index.ts +++ b/src/rpc/index.ts @@ -1,4 +1,4 @@ -import { methodGroup, methodNameToId, methods } from "./rpc"; +import { RPCArgumentType, RPCMethod, methodGroup, methodNameToId, methods } from "./rpc"; methodGroup(100); import "./apis/users"; @@ -12,10 +12,40 @@ methodGroup(500); import "./apis/attachments"; -console.log("--- begin rpc method map ---") -const methodMap: any = Object.fromEntries(methodNameToId); +export const methodMap: any = Object.fromEntries(methodNameToId); for (const key of Object.keys(methodMap)) { - methodMap[key] = [methodMap[key], methods.get(methodMap[key])?.requiresAuthentication]; + const id = methodMap[key]; + const method: RPCMethod | undefined = methods.get(id); + if (!method) continue; + + const actualArguments = method.args.map((arg) => { + const safeArg: any = { ...arg }; + if (arg.regexp) { + delete safeArg.regexp; + safeArg.hasRegex = true; + } + + switch (arg.type) { + case RPCArgumentType.Buffer: { + safeArg.type = "buffer" + break; + } + case RPCArgumentType.Integer: { + safeArg.type = "int"; + break; + } + case RPCArgumentType.String: { + safeArg.type = "string"; + break; + } + } + + return safeArg; + }); + + methodMap[key] = { + id, + requiresAuthentication: method.requiresAuthentication, + arguments: actualArguments + }; } -console.log(methodMap); -console.log("--- end rpc method map ---"); diff --git a/src/rpc/rpc.ts b/src/rpc/rpc.ts index 47a9d55..109a57d 100644 --- a/src/rpc/rpc.ts +++ b/src/rpc/rpc.ts @@ -10,10 +10,10 @@ export const channelNameRegex = new RegExp(/^[a-z0-9_\- ]+$/i); const defaultStringMaxLength = 3000; const defaultMaxBufferLength = maxBufferByteLength; -export const uint = (): RPCArgument => ({ type: RPCArgumentType.Integer, minValue: 0 }); -export const int = (minValue?: number, maxValue?: number): RPCArgument => ({ type: RPCArgumentType.Integer, minValue, maxValue }); -export const string = (minLength = 0, maxLength = defaultStringMaxLength): RPCArgument => ({ type: RPCArgumentType.String, minLength, maxLength }); -export const bufferSlice = (minLength = 0, maxLength = defaultMaxBufferLength) => ({ type: RPCArgumentType.Buffer, minLength, maxLength }); +export const uint = (name: string, description: string): RPCArgument => ({ name, description, type: RPCArgumentType.Integer, minValue: 0 }); +export const int = (name: string, description: string, minValue?: number, maxValue?: number): RPCArgument => ({ name, description, type: RPCArgumentType.Integer, minValue, maxValue }); +export const string = (name: string, description: string, minLength = 0, maxLength = defaultStringMaxLength): RPCArgument => ({ name, description, type: RPCArgumentType.String, minLength, maxLength }); +export const bufferSlice = (name: string, description: string, minLength = 0, maxLength = defaultMaxBufferLength) => ({ name, description, type: RPCArgumentType.Buffer, minLength, maxLength }); export const withRegexp = (regexp: RegExp, arg: RPCArgument): RPCArgument => ({ minLength: 0, maxLength: defaultStringMaxLength, ...arg, regexp }); export const withOptional = (arg: RPCArgument): RPCArgument => ({ ...arg, isOptional: true }); @@ -32,7 +32,7 @@ const defaultRPCContext: RPCContext = { }; -enum RPCArgumentType { +export enum RPCArgumentType { Integer, String, Buffer @@ -41,6 +41,8 @@ enum RPCArgumentType { interface RPCArgument { type: RPCArgumentType isOptional?: boolean + name: string, + description: string, // strings minLength?: number // also used for buffer @@ -52,7 +54,7 @@ interface RPCArgument { maxValue?: number } -interface RPCMethod { +export interface RPCMethod { args: RPCArgument[], func: ((...args: any[]) => any) requiresAuthentication: boolean