diff --git a/src/auth.ts b/src/auth.ts index cd5238b..1c62eb0 100644 --- a/src/auth.ts +++ b/src/auth.ts @@ -1,8 +1,9 @@ import { NextFunction, Request, Response } from "express"; -import { JwtPayload, sign, verify } from "jsonwebtoken"; +import { sign, verify } from "jsonwebtoken"; import { query } from "./database"; import { errors } from "./errors"; import serverConfig from "./serverconfig"; +import { compare } from "bcrypt"; const jwtSecret = process.env.JWT_SECRET || "[generic token]"; @@ -30,7 +31,7 @@ export function getPublicUserObject(user: User) { return newUser; } -export function signToken(userId: number) { +export function signToken(userId: number): Promise { return new Promise((resolve, reject) => { const payload = { id: userId, @@ -48,7 +49,7 @@ export function signToken(userId: number) { reject(error); return; } - if (!encoded) { + if (!encoded || typeof encoded !== "string") { reject("got undefined encoded value"); return; } @@ -102,6 +103,19 @@ export async function decodeTokenOrNull(encoded: string): Promise { + const existingUser = await query("SELECT * FROM users WHERE username = $1", [username]); + if (!existingUser || existingUser.rowCount < 1) { + return null; + } + + if (!await compare(password, existingUser.rows[0].password)) { + return null; + } + + return await signToken(existingUser.rows[0].id); +} + export function authenticateRoute() { return async (req: Request, res: Response, next: NextFunction) => { const pass = (user: User | null = null) => { @@ -117,17 +131,25 @@ export function authenticateRoute() { next(); }; - const authHeader = req.get("Authorization"); - if (!authHeader) return pass(); + let authHeader = req.get("Authorization"); + let token: string; + if (authHeader) { + const authParts = authHeader.split(" "); + if (authParts.length !== 2) return pass(); + + const [ authType, authToken ] = authParts; + if (authType !== "Bearer") return pass(); + if (typeof authToken !== "string") return pass(); - const authParts = authHeader.split(" "); - if (authParts.length !== 2) return pass(); + token = authToken; + } else { + let authToken = req.query.access_token; + if (typeof authToken !== "string") return pass(); - const [ authType, authToken ] = authParts; - if (authType !== "Bearer") return pass(); - if (typeof authToken !== "string") return pass(); + token = authToken; + } - decodeTokenOrNull(authToken).then((decoded) => { + decodeTokenOrNull(token).then((decoded) => { pass(decoded); }).catch(() => { pass(null); diff --git a/src/database/index.ts b/src/database/index.ts index 485285f..09b3529 100644 --- a/src/database/index.ts +++ b/src/database/index.ts @@ -1,4 +1,4 @@ -import { Pool, QueryResult } from "pg"; +import { Pool, PoolClient, QueryResult } from "pg"; const pool = new Pool(); @@ -23,3 +23,19 @@ export const query = function(text: string, params: any[] = [], rejectOnError = }); }); }; + +export async function withClient(callback: (client: PoolClient) => Promise): Promise { + const client = await pool.connect(); + let result = null + + try { + result = await callback(client); + } catch(o_O) { + console.error("error: exception during withClient callback, going to release client and rethrow..."); + client.release(); + throw o_O; + } + + client.release(); + return result; +} diff --git a/src/database/templates.ts b/src/database/templates.ts index f4c41d0..eaf2d5d 100644 --- a/src/database/templates.ts +++ b/src/database/templates.ts @@ -1,3 +1,4 @@ export const getMessageById = "SELECT messages.id, messages.content, messages.channel_id, messages.created_at, messages.author_id, users.username AS author_username FROM messages JOIN users ON messages.author_id = users.id WHERE messages.id = $1"; export const getMessagesByChannelFirstPage = (limit: number) => `SELECT messages.id, messages.content, messages.channel_id, messages.created_at, messages.author_id, users.username AS author_username FROM messages JOIN users ON messages.author_id = users.id WHERE messages.channel_id = $1 ORDER BY id DESC LIMIT ${limit}`; export const getMessagesByChannelPage = (limit: number) => `SELECT messages.id, messages.content, messages.channel_id, messages.created_at, messages.author_id, users.username AS author_username FROM messages JOIN users ON messages.author_id = users.id WHERE messages.id < $1 AND messages.channel_id = $2 ORDER BY id DESC LIMIT ${limit}`; +export const getMessagesByChannelAfterPage = (limit: number) => `SELECT messages.id, messages.content, messages.channel_id, messages.created_at, messages.author_id, users.username AS author_username FROM messages JOIN users ON messages.author_id = users.id WHERE messages.id > $1 AND messages.channel_id = $2 ORDER BY id DESC LIMIT ${limit}`; diff --git a/src/gateway/index.ts b/src/gateway/index.ts index 09ef7ad..6399084 100644 --- a/src/gateway/index.ts +++ b/src/gateway/index.ts @@ -19,6 +19,49 @@ const dispatchChannels = new Map>(); // mapping between a user id and the websocket sessions it has const sessionsByUserId = new Map(); +// mapping between a dispatch id and a temporary handler +const dispatchTemporary = new Map void>>(); + +export function handle(channels: string[], handler: (payload: GatewayPayload) => void): (() => any) { + channels.forEach(c => { + if (!dispatchTemporary.get(c)) { + dispatchTemporary.set(c, new Set()); + } + dispatchTemporary.get(c)?.add(handler); + }); + return () => { + channels.forEach(c => { + dispatchTemporary.get(c)?.delete(handler); + if (dispatchTemporary.get(c)?.size === 0) { + dispatchTemporary.delete(c); + } + }); + }; +} + +export function waitForEvent(channels: string[], timeout: number) { + return new Promise((resolve, reject) => { + let finished = false; + let clean = () => {}; + + const timeoutHandle = setTimeout(() => { + if (finished) return; + + finished = true; + clean(); + resolve(false); + }, timeout); + clean = handle(channels, () => { + if (finished) return; + + finished = true; + clearTimeout(timeoutHandle); + clean(); + resolve(true); + }); + }); +} + function clientSubscribe(ws: WebSocket, dispatchChannel: string) { ws.state.dispatchChannels.add(dispatchChannel); if (!dispatchChannels.get(dispatchChannel)) { @@ -67,10 +110,10 @@ function clientUnsubscribeAll(ws: WebSocket) { ws.state.dispatchChannels = new Set(); } -export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSocket) => GatewayPayload)) { +export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSocket | null) => GatewayPayload)) { const members = dispatchChannels.get(channel); if (!members) return; - + members.forEach(e => { if (e.state.ready) { let data = message; @@ -80,6 +123,18 @@ export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSoc e.send(JSON.stringify(data)); } }); + + const handlers = dispatchTemporary.get(channel); + if (handlers) { + handlers.forEach(e => { + let data = message; + if (typeof message === "function") { + data = message(null); + } + e(data as GatewayPayload); + }); + dispatchTemporary.delete(channel); + } } function closeWithError(ws: WebSocket, { code, message }: { code: number, message: string }) { diff --git a/src/impl.ts b/src/impl.ts new file mode 100644 index 0000000..1a3aa13 --- /dev/null +++ b/src/impl.ts @@ -0,0 +1,42 @@ +import { query } from "./database"; +import { dispatch } from "./gateway"; +import { GatewayPayloadType } from "./gateway/gatewaypayloadtype"; + +export default async function sendMessage(user: User, channelId: number, optimisticId: number | null, content: string) { + const authorId = user.id; + const createdAt = Date.now().toString(); + + const result = await query("INSERT INTO messages(content, channel_id, author_id, created_at) VALUES ($1, $2, $3, $4) RETURNING id", [content, channelId, authorId, createdAt]); + if (!result || result.rowCount < 1) { + return null; + } + + let returnObject: any = { + id: result.rows[0].id, + content, + channel_id: channelId, + author_id: authorId, + author_username: user.username, + created_at: createdAt + }; + + dispatch(`channel:${channelId}`, (ws) => { + let payload: any = returnObject; + if (ws && ws.state && ws.state.user && ws.state.user.id === user.id && optimisticId) { + payload = { + ...payload, + optimistic_id: optimisticId + } + returnObject = { + ...returnObject, + optimistic_id: optimisticId + }; + } + return { + t: GatewayPayloadType.MessageCreate, + d: payload + }; + }); + + return returnObject; +} diff --git a/src/routes/api/v1/channels.ts b/src/routes/api/v1/channels.ts index c4e4b67..947d479 100644 --- a/src/routes/api/v1/channels.ts +++ b/src/routes/api/v1/channels.ts @@ -6,6 +6,7 @@ import { getMessageById, getMessagesByChannelFirstPage, getMessagesByChannelPage import { errors } from "../../../errors"; import { dispatch, dispatchChannelSubscribe } from "../../../gateway"; import { GatewayPayloadType } from "../../../gateway/gatewaypayloadtype"; +import sendMessage from "../../../impl"; import serverConfig from "../../../serverconfig"; const router = express.Router(); @@ -180,47 +181,7 @@ router.post( return res.status(400).json({ ...errors.INVALID_DATA, errors: validationErrors.array() }); } - const optimisticId = parseInt(req.body.optimistic_id); - const channelId = parseInt(req.params.id); - const { content } = req.body; - const authorId = req.user.id; - const createdAt = Date.now().toString(); - - const result = await query("INSERT INTO messages(content, channel_id, author_id, created_at) VALUES ($1, $2, $3, $4) RETURNING id", [content, channelId, authorId, createdAt]); - if (!result || result.rowCount < 1) { - return res.status(500).json({ - ...errors.GOT_NO_DATABASE_DATA - }); - } - - let returnObject: any = { - id: result.rows[0].id, - content, - channel_id: channelId, - author_id: authorId, - author_username: req.user.username, - created_at: createdAt - }; - - dispatch(`channel:${channelId}`, (ws) => { - let payload: any = returnObject; - if (ws.state && ws.state.user && ws.state.user.id === req.user.id && optimisticId) { - payload = { - ...payload, - optimistic_id: optimisticId - } - returnObject = { - ...returnObject, - optimistic_id: optimisticId - }; - } - return { - t: GatewayPayloadType.MessageCreate, - d: payload - }; - }); - - return res.status(201).send(returnObject); + return res.status(201).send(await sendMessage(req.user, parseInt(req.params.id), parseInt(req.body.optimistic_id), req.body.content)); } ); diff --git a/src/routes/api/v1/users.ts b/src/routes/api/v1/users.ts index 5f59b88..c76fab4 100644 --- a/src/routes/api/v1/users.ts +++ b/src/routes/api/v1/users.ts @@ -3,7 +3,7 @@ import { query } from "../../../database"; import express from "express"; import { body, validationResult } from "express-validator"; import { compare, hash, hashSync } from "bcrypt"; -import { authenticateRoute, signToken } from "../../../auth"; +import { authenticateRoute, loginAttempt, signToken } from "../../../auth"; const router = express.Router(); @@ -76,19 +76,11 @@ router.post( return res.status(400).json({ ...errors.BAD_LOGIN_CREDENTIALS }); } - const { username, password } = req.body; - - const existingUser = await query("SELECT * FROM users WHERE username = $1", [username]); - if (!existingUser || existingUser.rowCount < 1) { + const token = await loginAttempt(req.body.username, req.body.password); + if (!token) { return res.status(400).json({ ...errors.BAD_LOGIN_CREDENTIALS }); } - if (!await compare(password, existingUser.rows[0].password)) { - return res.status(400).json({ ...errors.BAD_LOGIN_CREDENTIALS }); - } - - const token = await signToken(existingUser.rows[0].id); - return res.status(200).send({ token }); } ); diff --git a/src/routes/matrix/index.ts b/src/routes/matrix/index.ts new file mode 100644 index 0000000..66b86f3 --- /dev/null +++ b/src/routes/matrix/index.ts @@ -0,0 +1,417 @@ +import express from "express"; +import { body, validationResult } from "express-validator"; +import { PoolClient } from "pg"; +import { authenticateRoute, loginAttempt } from "../../auth"; +import { query, withClient } from "../../database"; +import { getMessagesByChannelAfterPage, getMessagesByChannelFirstPage } from "../../database/templates"; +import { handle, waitForEvent } from "../../gateway"; +import sendMessage from "../../impl"; + +const router = express.Router(); + +const matrixHomeserverBaseUrl = process.env.MATRIX_HOMESERVER_BASE_URL ? process.env.MATRIX_HOMESERVER_BASE_URL : "localhost:3000"; +const matrixWaffleAppUrl = process.env.MATRIX_WAFFLE_APP_URL ? process.env.MATRIX_WAFFLE_APP_URL : "localhost:3000"; +const matrixDeviceId = "xyz.hippoz.waffle.generic_matrix_client.device"; +const matrixRootUser = `@waffleGateway:${matrixHomeserverBaseUrl}`; + +const usernameToMatrix = (username: string) => `@${username}:${matrixHomeserverBaseUrl}`; +const inventRoomEventId = (channelId: number, kind: string) => `$${channelId}${kind}:${matrixHomeserverBaseUrl}`; +const roomToChannelId = (room?: any): number | null => { + if (typeof room !== "string") return null; + const parts = room.split(":", 1); + if (parts.length !== 1) return null; + const suffixedId = parts[0]; + if (!suffixedId.startsWith("!")) return null; + const id = suffixedId.substring(1, suffixedId.length); + if (!id || id.length < 1) return null; + const numberId = parseInt(id); + if (!isFinite(numberId) || isNaN(numberId)) return null; + + return numberId; +}; + +interface MatrixSyncCursors { + [channel_id: number]: number; +}; + +async function buildSyncPayload(user: User, cursors: MatrixSyncCursors, onlyOutstandingEvents: boolean, client: PoolClient, channels: Channel[]) { + const joinedChannels: any[any] = {}; + let nextBatchCursors = ""; + for (let i = 0; i < channels.length; i++) { + const channel = channels[i]; + const roomId = `!${channel.id}:${matrixHomeserverBaseUrl}`; + + let channelMessagesResult; + if (cursors[channel.id]) { + channelMessagesResult = await client.query(getMessagesByChannelAfterPage(50), [cursors[channel.id], channel.id]); + } else { + channelMessagesResult = await client.query(getMessagesByChannelFirstPage(50), [channel.id]); + } + const messages = channelMessagesResult && channelMessagesResult.rows ? channelMessagesResult.rows.reverse() : []; + const messagesTimeline = messages.map(e => ({ + content: { + body: e.content, + msgtype: "m.text", + }, + event_id: inventRoomEventId(e.id, "message"), + origin_server_ts: parseInt(e.created_at), + room_id: roomId, + sender: usernameToMatrix(e.author_username), + type: "m.room.message" + })); + + if (messages.length > 0) { + nextBatchCursors += `${channel.id}:${messages[messages.length - 1].id};`; + } else if (cursors[channel.id]) { + nextBatchCursors += `${channel.id}:${cursors[channel.id]};`; + } + + if (messages.length < 1 && onlyOutstandingEvents) { + continue; + } + + joinedChannels[roomId] = { + account_data: {events: []}, + ephemeral: {events: []}, + state: {events:[ + { + type: "m.room.member", + event_id: inventRoomEventId(channel.id, "member_join"), + origin_server_ts: 0, + room_id: roomId, + sender: usernameToMatrix(user.username), + state_key: usernameToMatrix(user.username), + content: { + membership: "join" + } + }, + { + type: "m.room.name", + sender: matrixRootUser, + state_key: "", + content: { + name: channel.name + } + }, + { + type: "m.room.create", + sender: matrixRootUser, + state_key: "", + content: { + creator: matrixRootUser, + "m.federate": false, + room_version: 1 + } + }, + ]}, + summary: { + "m.heroes": [ + matrixRootUser + ], + "m.invited_member_count": 0, + "m.joined_member_count": 0 + }, + timeline: {events: [ + ...messagesTimeline + ], limited: false, prev_batch: "__prev_batch__not_implemented__"}, + unread_notifications: { + highlight_count: 0, + notification_count: 0 + }, + }; + } + + return { + next_batch: nextBatchCursors, + rooms: { + invite: {}, + join: joinedChannels, + knock: {}, + leave: {} + } + }; +} + +router.use((_req, res, next) => { + res.header("Access-Control-Allow-Origin", "*"); + res.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"); + res.header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization"); + next(); +}); + +router.get("/.well-known/matrix/client", (req, res) => { + res.json({ + "m.homeserver": { + "base_url": `http://${matrixHomeserverBaseUrl}` + }, + "xyz.hippoz.waffle": { + "app_url": matrixWaffleAppUrl + } + }); +}); + +router.get("/_matrix/client/versions", (req, res) => { + res.json({ + versions: [ + "v1.4" + ] + }); +}); + +router.get("/_matrix/client/r0/login", (req, res) => { + res.json({ + flows: [ + { + type: "m.login.password" + } + ] + }); +}); + +router.post("/_matrix/client/r0/register", (req, res) => { + res.status(401).json({ + errcode: "M_FORBIDDEN", + error: "Registration is not implemented" + }); +}); + +router.post( + "/_matrix/client/r0/login", + async (req, res) => { + if (req.body.type !== "m.login.password") { + return res.status(403).json({ + errcode: "M_FORBIDDEN", + error: "expected type to be 'm.login.password'" + }); + } + + if (typeof req.body.identifier !== "object" || req.body.identifier.type !== "m.id.user" || typeof req.body.identifier.user !== "string") { + return res.status(403).json({ + errcode: "M_FORBIDDEN", + error: "Bad identifier" + }); + } + + if (req.body.identifier.user.length < 3 || req.body.identifier.user.length > 32) { + return res.status(403).json({ + errcode: "M_FORBIDDEN", + error: "Bad username - expected string between 3 and 32 characters" + }); + } + + if (typeof req.body.password !== "string" || req.body.password.length > 1000 || req.body.password.length < 8) { + return res.status(403).json({ + errcode: "M_FORBIDDEN", + error: "Bad password - expected string between 8 and 1000 characters" + }); + } + + const username = req.body.identifier.user; + const password = req.body.password; + + const token = await loginAttempt(username, password); + if (!token) { + return res.status(403).json({ + errcode: "M_FORBIDDEN", + error: "Invalid credentials" + }); + } + + return res.status(200).send({ + access_token: token, + device_id: matrixDeviceId, + // FIXME: expires_in_ms + user_id: usernameToMatrix(username) + }); + } +); + +router.get( + "/_matrix/client/r0/pushrules", + authenticateRoute(), + (req, res) => { + res.json({global: {}}); + } +); + +router.post( + "/_matrix/client/r0/user/:userId/filter", + authenticateRoute(), + (req, res) => { + res.json({filter_id: "-1"}); + } +); + +router.get( + "/_matrix/client/r0/user/:userId/filter/:filterId", + authenticateRoute(), + (req, res) => { + res.json({}); + } +); + +router.put( + "/_matrix/client/r0/rooms/:roomId/send/:eventType/:txnId", + authenticateRoute(), + async (req, res) => { + const channelId = roomToChannelId(req.params.roomId); + if (channelId === null) { + return res.status(400).json({ + errcode: "M_BAD_JSON", + error: "Bad room id" + }); + } + + const eventType = req.params.eventType; + if (typeof eventType !== "string") { + return res.status(400).json({ + errcode: "M_BAD_JSON", + error: "Bad event type" + }); + } + + if (eventType === "m.room.message") { + if (req.body.msgtype !== "m.text") { + return res.status(400).json({ + errcode: "M_BAD_JSON", + error: "msgtype can only be 'm.text'" + }); + } + if (typeof req.body.body !== "string" || req.body.body.length < 1 || req.body.body.length > 2000) { + return res.status(400).json({ + errcode: "M_BAD_JSON", + error: "Message body must be a string between 1 and 2000 characters" + }); + } + const message = await sendMessage(req.user, channelId, null, req.body.body); + if (!message) { + return res.status(500).json({ + errcode: "M_UNKNOWN", + error: "Failed to send message" + }); + } + + res.status(200).json({ + event_id: inventRoomEventId(message.id, "message") + }); + } else { + res.status(400).json({ + errcode: "M_BAD_JSON", + error: "Unsupported event type" + }); + } + } +); + +router.get( + "/_matrix/client/r0/sync", + authenticateRoute(), + async (req, res) => { + let timeout = 0; + let since: string | null = null; + let isInitial = true; + if (typeof req.query.timeout === "string") { + timeout = parseInt(req.query.timeout); + if (!timeout || isNaN(timeout) || timeout < 5 || timeout > 60000) { + timeout = 0; + } + } + if (typeof req.query.since === "string") { + since = req.query.since; + } + if (since) { + isInitial = false; + } + + await withClient((client: PoolClient) => { + return new Promise(async (resolve, reject) => { + const channelsResult = await client.query("SELECT id, name, owner_id FROM channels"); + if (!channelsResult) return; + const channels = channelsResult.rows || []; + + const sendPayload = async () => { + const cursors: MatrixSyncCursors = {}; + if (since) { + const sinceParts = since.split(";"); + if (sinceParts.length < 100) { + sinceParts.forEach((part) => { + const def = part.split(":"); + if (def.length !== 2) return; + const channelId = parseInt(def[0]); + const page = parseInt(def[1]); + if (!isNaN(channelId) && isFinite(channelId) && !isNaN(page) && isFinite(page)) { + cursors[channelId] = page; + } + }); + } + } + + res.json(await buildSyncPayload({ + id: 3, + username: "test", + is_superuser: true + }, cursors, !isInitial, client, channels)); + }; + + if (timeout) { + let dispatchChannels = ["*"]; + channels.forEach(channel => dispatchChannels.push(`channel:${channel.id}`)); + await waitForEvent(dispatchChannels, timeout); + } + await sendPayload(); + resolve(true); + }); + }); + } +); + +router.get( + "/_matrix/client/r0/devices", + authenticateRoute(), + (req, res) => { + res.json({ + devices: [ + { + device_id: matrixDeviceId, + display_name: "Waffle Generic Matrix Device", + last_seen_ip: "0.0.0.0", + last_seen_ts: 0 + } + ] + }); + } +); + +router.post( + "/_matrix/client/r0/keys/query", + authenticateRoute(), + (req, res) => { + res.json({ + device_keys: {}, + fallback_keys: {}, + one_time_keys: {} + }); + } +); + +router.post( + "/_matrix/client/r0/keys/upload", + authenticateRoute(), + (req, res) => { + res.json({ + one_time_key_counts: {signed_curve25519:2000} + }); + } +); + +router.get( + "/_matrix/client/r0/profile/:userId", + authenticateRoute(), + (req, res) => { + res.json({ + displayname: req.user.username + }); + } +); + +export default router; diff --git a/src/server.ts b/src/server.ts index 84a483a..b3b9b21 100644 --- a/src/server.ts +++ b/src/server.ts @@ -2,6 +2,7 @@ import express, { Application, ErrorRequestHandler, json } from "express"; import usersRouter from "./routes/api/v1/users"; import channelsRouter from "./routes/api/v1/channels"; import messagesRouter from "./routes/api/v1/messages"; +import matrixRouter from "./routes/matrix"; import { errors } from "./errors"; export default function(app: Application) { @@ -10,6 +11,7 @@ export default function(app: Application) { app.use("/api/v1/channels", channelsRouter); app.use("/api/v1/messages", messagesRouter); app.use("/", express.static("frontend/public")); + app.use("/", matrixRouter); const errorHandler: ErrorRequestHandler = (error, req, res, next) => { console.error("error: while handling request", error); diff --git a/src/types/channel.d.ts b/src/types/channel.d.ts index 73ca86b..7fad630 100644 --- a/src/types/channel.d.ts +++ b/src/types/channel.d.ts @@ -1,5 +1,5 @@ interface Channel { id: number, - name: number, + name: string, owner_id: number }