Matrix implementation

Implement a very small subset of the Matrix protocol.
This commit is contained in:
hippoz 2022-10-09 22:46:31 +03:00
parent a0438e1d45
commit 425406c88a
Signed by: hippoz
GPG key ID: 7C52899193467641
10 changed files with 575 additions and 67 deletions

View file

@ -1,8 +1,9 @@
import { NextFunction, Request, Response } from "express";
import { JwtPayload, sign, verify } from "jsonwebtoken";
import { sign, verify } from "jsonwebtoken";
import { query } from "./database";
import { errors } from "./errors";
import serverConfig from "./serverconfig";
import { compare } from "bcrypt";
const jwtSecret = process.env.JWT_SECRET || "[generic token]";
@ -30,7 +31,7 @@ export function getPublicUserObject(user: User) {
return newUser;
}
export function signToken(userId: number) {
export function signToken(userId: number): Promise<string> {
return new Promise((resolve, reject) => {
const payload = {
id: userId,
@ -48,7 +49,7 @@ export function signToken(userId: number) {
reject(error);
return;
}
if (!encoded) {
if (!encoded || typeof encoded !== "string") {
reject("got undefined encoded value");
return;
}
@ -102,6 +103,19 @@ export async function decodeTokenOrNull(encoded: string): Promise<User | undefin
}
}
export async function loginAttempt(username: string, password: string): Promise<string | null> {
const existingUser = await query("SELECT * FROM users WHERE username = $1", [username]);
if (!existingUser || existingUser.rowCount < 1) {
return null;
}
if (!await compare(password, existingUser.rows[0].password)) {
return null;
}
return await signToken(existingUser.rows[0].id);
}
export function authenticateRoute() {
return async (req: Request, res: Response, next: NextFunction) => {
const pass = (user: User | null = null) => {
@ -117,17 +131,25 @@ export function authenticateRoute() {
next();
};
const authHeader = req.get("Authorization");
if (!authHeader) return pass();
let authHeader = req.get("Authorization");
let token: string;
if (authHeader) {
const authParts = authHeader.split(" ");
if (authParts.length !== 2) return pass();
const [ authType, authToken ] = authParts;
if (authType !== "Bearer") return pass();
if (typeof authToken !== "string") return pass();
const authParts = authHeader.split(" ");
if (authParts.length !== 2) return pass();
token = authToken;
} else {
let authToken = req.query.access_token;
if (typeof authToken !== "string") return pass();
const [ authType, authToken ] = authParts;
if (authType !== "Bearer") return pass();
if (typeof authToken !== "string") return pass();
token = authToken;
}
decodeTokenOrNull(authToken).then((decoded) => {
decodeTokenOrNull(token).then((decoded) => {
pass(decoded);
}).catch(() => {
pass(null);

View file

@ -1,4 +1,4 @@
import { Pool, QueryResult } from "pg";
import { Pool, PoolClient, QueryResult } from "pg";
const pool = new Pool();
@ -23,3 +23,19 @@ export const query = function(text: string, params: any[] = [], rejectOnError =
});
});
};
export async function withClient(callback: (client: PoolClient) => Promise<any>): Promise<any> {
const client = await pool.connect();
let result = null
try {
result = await callback(client);
} catch(o_O) {
console.error("error: exception during withClient callback, going to release client and rethrow...");
client.release();
throw o_O;
}
client.release();
return result;
}

View file

@ -1,3 +1,4 @@
export const getMessageById = "SELECT messages.id, messages.content, messages.channel_id, messages.created_at, messages.author_id, users.username AS author_username FROM messages JOIN users ON messages.author_id = users.id WHERE messages.id = $1";
export const getMessagesByChannelFirstPage = (limit: number) => `SELECT messages.id, messages.content, messages.channel_id, messages.created_at, messages.author_id, users.username AS author_username FROM messages JOIN users ON messages.author_id = users.id WHERE messages.channel_id = $1 ORDER BY id DESC LIMIT ${limit}`;
export const getMessagesByChannelPage = (limit: number) => `SELECT messages.id, messages.content, messages.channel_id, messages.created_at, messages.author_id, users.username AS author_username FROM messages JOIN users ON messages.author_id = users.id WHERE messages.id < $1 AND messages.channel_id = $2 ORDER BY id DESC LIMIT ${limit}`;
export const getMessagesByChannelAfterPage = (limit: number) => `SELECT messages.id, messages.content, messages.channel_id, messages.created_at, messages.author_id, users.username AS author_username FROM messages JOIN users ON messages.author_id = users.id WHERE messages.id > $1 AND messages.channel_id = $2 ORDER BY id DESC LIMIT ${limit}`;

View file

@ -19,6 +19,49 @@ const dispatchChannels = new Map<string, Set<WebSocket>>();
// mapping between a user id and the websocket sessions it has
const sessionsByUserId = new Map<number, WebSocket[]>();
// mapping between a dispatch id and a temporary handler
const dispatchTemporary = new Map<string, Set<(payload: GatewayPayload) => void>>();
export function handle(channels: string[], handler: (payload: GatewayPayload) => void): (() => any) {
channels.forEach(c => {
if (!dispatchTemporary.get(c)) {
dispatchTemporary.set(c, new Set());
}
dispatchTemporary.get(c)?.add(handler);
});
return () => {
channels.forEach(c => {
dispatchTemporary.get(c)?.delete(handler);
if (dispatchTemporary.get(c)?.size === 0) {
dispatchTemporary.delete(c);
}
});
};
}
export function waitForEvent(channels: string[], timeout: number) {
return new Promise((resolve, reject) => {
let finished = false;
let clean = () => {};
const timeoutHandle = setTimeout(() => {
if (finished) return;
finished = true;
clean();
resolve(false);
}, timeout);
clean = handle(channels, () => {
if (finished) return;
finished = true;
clearTimeout(timeoutHandle);
clean();
resolve(true);
});
});
}
function clientSubscribe(ws: WebSocket, dispatchChannel: string) {
ws.state.dispatchChannels.add(dispatchChannel);
if (!dispatchChannels.get(dispatchChannel)) {
@ -67,10 +110,10 @@ function clientUnsubscribeAll(ws: WebSocket) {
ws.state.dispatchChannels = new Set();
}
export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSocket) => GatewayPayload)) {
export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSocket | null) => GatewayPayload)) {
const members = dispatchChannels.get(channel);
if (!members) return;
members.forEach(e => {
if (e.state.ready) {
let data = message;
@ -80,6 +123,18 @@ export function dispatch(channel: string, message: GatewayPayload | ((ws: WebSoc
e.send(JSON.stringify(data));
}
});
const handlers = dispatchTemporary.get(channel);
if (handlers) {
handlers.forEach(e => {
let data = message;
if (typeof message === "function") {
data = message(null);
}
e(data as GatewayPayload);
});
dispatchTemporary.delete(channel);
}
}
function closeWithError(ws: WebSocket, { code, message }: { code: number, message: string }) {

42
src/impl.ts Normal file
View file

@ -0,0 +1,42 @@
import { query } from "./database";
import { dispatch } from "./gateway";
import { GatewayPayloadType } from "./gateway/gatewaypayloadtype";
export default async function sendMessage(user: User, channelId: number, optimisticId: number | null, content: string) {
const authorId = user.id;
const createdAt = Date.now().toString();
const result = await query("INSERT INTO messages(content, channel_id, author_id, created_at) VALUES ($1, $2, $3, $4) RETURNING id", [content, channelId, authorId, createdAt]);
if (!result || result.rowCount < 1) {
return null;
}
let returnObject: any = {
id: result.rows[0].id,
content,
channel_id: channelId,
author_id: authorId,
author_username: user.username,
created_at: createdAt
};
dispatch(`channel:${channelId}`, (ws) => {
let payload: any = returnObject;
if (ws && ws.state && ws.state.user && ws.state.user.id === user.id && optimisticId) {
payload = {
...payload,
optimistic_id: optimisticId
}
returnObject = {
...returnObject,
optimistic_id: optimisticId
};
}
return {
t: GatewayPayloadType.MessageCreate,
d: payload
};
});
return returnObject;
}

View file

@ -6,6 +6,7 @@ import { getMessageById, getMessagesByChannelFirstPage, getMessagesByChannelPage
import { errors } from "../../../errors";
import { dispatch, dispatchChannelSubscribe } from "../../../gateway";
import { GatewayPayloadType } from "../../../gateway/gatewaypayloadtype";
import sendMessage from "../../../impl";
import serverConfig from "../../../serverconfig";
const router = express.Router();
@ -180,47 +181,7 @@ router.post(
return res.status(400).json({ ...errors.INVALID_DATA, errors: validationErrors.array() });
}
const optimisticId = parseInt(req.body.optimistic_id);
const channelId = parseInt(req.params.id);
const { content } = req.body;
const authorId = req.user.id;
const createdAt = Date.now().toString();
const result = await query("INSERT INTO messages(content, channel_id, author_id, created_at) VALUES ($1, $2, $3, $4) RETURNING id", [content, channelId, authorId, createdAt]);
if (!result || result.rowCount < 1) {
return res.status(500).json({
...errors.GOT_NO_DATABASE_DATA
});
}
let returnObject: any = {
id: result.rows[0].id,
content,
channel_id: channelId,
author_id: authorId,
author_username: req.user.username,
created_at: createdAt
};
dispatch(`channel:${channelId}`, (ws) => {
let payload: any = returnObject;
if (ws.state && ws.state.user && ws.state.user.id === req.user.id && optimisticId) {
payload = {
...payload,
optimistic_id: optimisticId
}
returnObject = {
...returnObject,
optimistic_id: optimisticId
};
}
return {
t: GatewayPayloadType.MessageCreate,
d: payload
};
});
return res.status(201).send(returnObject);
return res.status(201).send(await sendMessage(req.user, parseInt(req.params.id), parseInt(req.body.optimistic_id), req.body.content));
}
);

View file

@ -3,7 +3,7 @@ import { query } from "../../../database";
import express from "express";
import { body, validationResult } from "express-validator";
import { compare, hash, hashSync } from "bcrypt";
import { authenticateRoute, signToken } from "../../../auth";
import { authenticateRoute, loginAttempt, signToken } from "../../../auth";
const router = express.Router();
@ -76,19 +76,11 @@ router.post(
return res.status(400).json({ ...errors.BAD_LOGIN_CREDENTIALS });
}
const { username, password } = req.body;
const existingUser = await query("SELECT * FROM users WHERE username = $1", [username]);
if (!existingUser || existingUser.rowCount < 1) {
const token = await loginAttempt(req.body.username, req.body.password);
if (!token) {
return res.status(400).json({ ...errors.BAD_LOGIN_CREDENTIALS });
}
if (!await compare(password, existingUser.rows[0].password)) {
return res.status(400).json({ ...errors.BAD_LOGIN_CREDENTIALS });
}
const token = await signToken(existingUser.rows[0].id);
return res.status(200).send({ token });
}
);

417
src/routes/matrix/index.ts Normal file
View file

@ -0,0 +1,417 @@
import express from "express";
import { body, validationResult } from "express-validator";
import { PoolClient } from "pg";
import { authenticateRoute, loginAttempt } from "../../auth";
import { query, withClient } from "../../database";
import { getMessagesByChannelAfterPage, getMessagesByChannelFirstPage } from "../../database/templates";
import { handle, waitForEvent } from "../../gateway";
import sendMessage from "../../impl";
const router = express.Router();
const matrixHomeserverBaseUrl = process.env.MATRIX_HOMESERVER_BASE_URL ? process.env.MATRIX_HOMESERVER_BASE_URL : "localhost:3000";
const matrixWaffleAppUrl = process.env.MATRIX_WAFFLE_APP_URL ? process.env.MATRIX_WAFFLE_APP_URL : "localhost:3000";
const matrixDeviceId = "xyz.hippoz.waffle.generic_matrix_client.device";
const matrixRootUser = `@waffleGateway:${matrixHomeserverBaseUrl}`;
const usernameToMatrix = (username: string) => `@${username}:${matrixHomeserverBaseUrl}`;
const inventRoomEventId = (channelId: number, kind: string) => `$${channelId}${kind}:${matrixHomeserverBaseUrl}`;
const roomToChannelId = (room?: any): number | null => {
if (typeof room !== "string") return null;
const parts = room.split(":", 1);
if (parts.length !== 1) return null;
const suffixedId = parts[0];
if (!suffixedId.startsWith("!")) return null;
const id = suffixedId.substring(1, suffixedId.length);
if (!id || id.length < 1) return null;
const numberId = parseInt(id);
if (!isFinite(numberId) || isNaN(numberId)) return null;
return numberId;
};
interface MatrixSyncCursors {
[channel_id: number]: number;
};
async function buildSyncPayload(user: User, cursors: MatrixSyncCursors, onlyOutstandingEvents: boolean, client: PoolClient, channels: Channel[]) {
const joinedChannels: any[any] = {};
let nextBatchCursors = "";
for (let i = 0; i < channels.length; i++) {
const channel = channels[i];
const roomId = `!${channel.id}:${matrixHomeserverBaseUrl}`;
let channelMessagesResult;
if (cursors[channel.id]) {
channelMessagesResult = await client.query(getMessagesByChannelAfterPage(50), [cursors[channel.id], channel.id]);
} else {
channelMessagesResult = await client.query(getMessagesByChannelFirstPage(50), [channel.id]);
}
const messages = channelMessagesResult && channelMessagesResult.rows ? channelMessagesResult.rows.reverse() : [];
const messagesTimeline = messages.map(e => ({
content: {
body: e.content,
msgtype: "m.text",
},
event_id: inventRoomEventId(e.id, "message"),
origin_server_ts: parseInt(e.created_at),
room_id: roomId,
sender: usernameToMatrix(e.author_username),
type: "m.room.message"
}));
if (messages.length > 0) {
nextBatchCursors += `${channel.id}:${messages[messages.length - 1].id};`;
} else if (cursors[channel.id]) {
nextBatchCursors += `${channel.id}:${cursors[channel.id]};`;
}
if (messages.length < 1 && onlyOutstandingEvents) {
continue;
}
joinedChannels[roomId] = {
account_data: {events: []},
ephemeral: {events: []},
state: {events:[
{
type: "m.room.member",
event_id: inventRoomEventId(channel.id, "member_join"),
origin_server_ts: 0,
room_id: roomId,
sender: usernameToMatrix(user.username),
state_key: usernameToMatrix(user.username),
content: {
membership: "join"
}
},
{
type: "m.room.name",
sender: matrixRootUser,
state_key: "",
content: {
name: channel.name
}
},
{
type: "m.room.create",
sender: matrixRootUser,
state_key: "",
content: {
creator: matrixRootUser,
"m.federate": false,
room_version: 1
}
},
]},
summary: {
"m.heroes": [
matrixRootUser
],
"m.invited_member_count": 0,
"m.joined_member_count": 0
},
timeline: {events: [
...messagesTimeline
], limited: false, prev_batch: "__prev_batch__not_implemented__"},
unread_notifications: {
highlight_count: 0,
notification_count: 0
},
};
}
return {
next_batch: nextBatchCursors,
rooms: {
invite: {},
join: joinedChannels,
knock: {},
leave: {}
}
};
}
router.use((_req, res, next) => {
res.header("Access-Control-Allow-Origin", "*");
res.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
res.header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization");
next();
});
router.get("/.well-known/matrix/client", (req, res) => {
res.json({
"m.homeserver": {
"base_url": `http://${matrixHomeserverBaseUrl}`
},
"xyz.hippoz.waffle": {
"app_url": matrixWaffleAppUrl
}
});
});
router.get("/_matrix/client/versions", (req, res) => {
res.json({
versions: [
"v1.4"
]
});
});
router.get("/_matrix/client/r0/login", (req, res) => {
res.json({
flows: [
{
type: "m.login.password"
}
]
});
});
router.post("/_matrix/client/r0/register", (req, res) => {
res.status(401).json({
errcode: "M_FORBIDDEN",
error: "Registration is not implemented"
});
});
router.post(
"/_matrix/client/r0/login",
async (req, res) => {
if (req.body.type !== "m.login.password") {
return res.status(403).json({
errcode: "M_FORBIDDEN",
error: "expected type to be 'm.login.password'"
});
}
if (typeof req.body.identifier !== "object" || req.body.identifier.type !== "m.id.user" || typeof req.body.identifier.user !== "string") {
return res.status(403).json({
errcode: "M_FORBIDDEN",
error: "Bad identifier"
});
}
if (req.body.identifier.user.length < 3 || req.body.identifier.user.length > 32) {
return res.status(403).json({
errcode: "M_FORBIDDEN",
error: "Bad username - expected string between 3 and 32 characters"
});
}
if (typeof req.body.password !== "string" || req.body.password.length > 1000 || req.body.password.length < 8) {
return res.status(403).json({
errcode: "M_FORBIDDEN",
error: "Bad password - expected string between 8 and 1000 characters"
});
}
const username = req.body.identifier.user;
const password = req.body.password;
const token = await loginAttempt(username, password);
if (!token) {
return res.status(403).json({
errcode: "M_FORBIDDEN",
error: "Invalid credentials"
});
}
return res.status(200).send({
access_token: token,
device_id: matrixDeviceId,
// FIXME: expires_in_ms
user_id: usernameToMatrix(username)
});
}
);
router.get(
"/_matrix/client/r0/pushrules",
authenticateRoute(),
(req, res) => {
res.json({global: {}});
}
);
router.post(
"/_matrix/client/r0/user/:userId/filter",
authenticateRoute(),
(req, res) => {
res.json({filter_id: "-1"});
}
);
router.get(
"/_matrix/client/r0/user/:userId/filter/:filterId",
authenticateRoute(),
(req, res) => {
res.json({});
}
);
router.put(
"/_matrix/client/r0/rooms/:roomId/send/:eventType/:txnId",
authenticateRoute(),
async (req, res) => {
const channelId = roomToChannelId(req.params.roomId);
if (channelId === null) {
return res.status(400).json({
errcode: "M_BAD_JSON",
error: "Bad room id"
});
}
const eventType = req.params.eventType;
if (typeof eventType !== "string") {
return res.status(400).json({
errcode: "M_BAD_JSON",
error: "Bad event type"
});
}
if (eventType === "m.room.message") {
if (req.body.msgtype !== "m.text") {
return res.status(400).json({
errcode: "M_BAD_JSON",
error: "msgtype can only be 'm.text'"
});
}
if (typeof req.body.body !== "string" || req.body.body.length < 1 || req.body.body.length > 2000) {
return res.status(400).json({
errcode: "M_BAD_JSON",
error: "Message body must be a string between 1 and 2000 characters"
});
}
const message = await sendMessage(req.user, channelId, null, req.body.body);
if (!message) {
return res.status(500).json({
errcode: "M_UNKNOWN",
error: "Failed to send message"
});
}
res.status(200).json({
event_id: inventRoomEventId(message.id, "message")
});
} else {
res.status(400).json({
errcode: "M_BAD_JSON",
error: "Unsupported event type"
});
}
}
);
router.get(
"/_matrix/client/r0/sync",
authenticateRoute(),
async (req, res) => {
let timeout = 0;
let since: string | null = null;
let isInitial = true;
if (typeof req.query.timeout === "string") {
timeout = parseInt(req.query.timeout);
if (!timeout || isNaN(timeout) || timeout < 5 || timeout > 60000) {
timeout = 0;
}
}
if (typeof req.query.since === "string") {
since = req.query.since;
}
if (since) {
isInitial = false;
}
await withClient((client: PoolClient) => {
return new Promise(async (resolve, reject) => {
const channelsResult = await client.query("SELECT id, name, owner_id FROM channels");
if (!channelsResult) return;
const channels = channelsResult.rows || [];
const sendPayload = async () => {
const cursors: MatrixSyncCursors = {};
if (since) {
const sinceParts = since.split(";");
if (sinceParts.length < 100) {
sinceParts.forEach((part) => {
const def = part.split(":");
if (def.length !== 2) return;
const channelId = parseInt(def[0]);
const page = parseInt(def[1]);
if (!isNaN(channelId) && isFinite(channelId) && !isNaN(page) && isFinite(page)) {
cursors[channelId] = page;
}
});
}
}
res.json(await buildSyncPayload({
id: 3,
username: "test",
is_superuser: true
}, cursors, !isInitial, client, channels));
};
if (timeout) {
let dispatchChannels = ["*"];
channels.forEach(channel => dispatchChannels.push(`channel:${channel.id}`));
await waitForEvent(dispatchChannels, timeout);
}
await sendPayload();
resolve(true);
});
});
}
);
router.get(
"/_matrix/client/r0/devices",
authenticateRoute(),
(req, res) => {
res.json({
devices: [
{
device_id: matrixDeviceId,
display_name: "Waffle Generic Matrix Device",
last_seen_ip: "0.0.0.0",
last_seen_ts: 0
}
]
});
}
);
router.post(
"/_matrix/client/r0/keys/query",
authenticateRoute(),
(req, res) => {
res.json({
device_keys: {},
fallback_keys: {},
one_time_keys: {}
});
}
);
router.post(
"/_matrix/client/r0/keys/upload",
authenticateRoute(),
(req, res) => {
res.json({
one_time_key_counts: {signed_curve25519:2000}
});
}
);
router.get(
"/_matrix/client/r0/profile/:userId",
authenticateRoute(),
(req, res) => {
res.json({
displayname: req.user.username
});
}
);
export default router;

View file

@ -2,6 +2,7 @@ import express, { Application, ErrorRequestHandler, json } from "express";
import usersRouter from "./routes/api/v1/users";
import channelsRouter from "./routes/api/v1/channels";
import messagesRouter from "./routes/api/v1/messages";
import matrixRouter from "./routes/matrix";
import { errors } from "./errors";
export default function(app: Application) {
@ -10,6 +11,7 @@ export default function(app: Application) {
app.use("/api/v1/channels", channelsRouter);
app.use("/api/v1/messages", messagesRouter);
app.use("/", express.static("frontend/public"));
app.use("/", matrixRouter);
const errorHandler: ErrorRequestHandler = (error, req, res, next) => {
console.error("error: while handling request", error);

View file

@ -1,5 +1,5 @@
interface Channel {
id: number,
name: number,
name: string,
owner_id: number
}