/*
 * Publish/subscribe relay over a UNIX-domain stream socket.
 *
 * Protocol (each recv() chunk is parsed as exactly one command; there is no
 * message framing):
 *   listen <channel>            subscribe this client to <channel>
 *   send <channel> <payload>    forward <payload> to every listener of <channel>
 * On any protocol error the client is sent a short message and disconnected.
 */
#include <errno.h>
#include <poll.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#define MAX_CHANNELS 128
#define MAX_CLIENTS 128
#define MAX_CLIENTS_PER_CHANNEL 64
#define CHANNEL_NAME_LENGTH 28 /* bytes, including the terminating NUL */
#define CHANNEL_PROBE_LIMIT 12 /* slots examined per lookup in the channel table */
#define RECV_BUFFER_SIZE 256
#define BACKLOG 10
#define TOKEN_DELIMS " \t\r\n"

/* Index of the listening socket in server.fds; clients occupy 0..MAX_CLIENTS-1. */
#define LISTEN_SLOT MAX_CLIENTS

struct client {
    int fd; /* -1 when the slot is free */
};

struct channel {
    char name[CHANNEL_NAME_LENGTH];       /* NUL-padded: hashed over its full length */
    int clients[MAX_CLIENTS_PER_CHANNEL]; /* client slot indices, -1 = empty entry */
    int clients_count;
};

struct server {
    int sock_fd;
    int fd_num;
    struct channel channels[MAX_CHANNELS];
    struct client clients[MAX_CLIENTS];
    struct pollfd fds[MAX_CLIENTS + 1];
};

/*
 * FNV-1a hash of `bytes_n` bytes, reduced modulo `map_len` (returns 0 when
 * map_len is 0 instead of dividing by zero).
 * https://en.wikipedia.org/wiki/Fowler-Noll-Vo_hash_function#FNV-1a_hash
 */
uint64_t hashmap_hash(const char *bytes, size_t bytes_n, size_t map_len)
{
    uint64_t hash = 0xcbf29ce484222325ULL; /* FNV-64 offset basis */

    if (map_len == 0)
        return 0;
    for (size_t i = 0; i < bytes_n; i++) {
        /* FNV-1a order is XOR then multiply; the unsigned char cast stops
           bytes >= 0x80 from being sign-extended into the XOR. */
        hash ^= (unsigned char)bytes[i];
        hash *= 0x100000001b3ULL; /* FNV-64 prime */
    }
    return hash % map_len;
}

/* Stores `fd` in the first free client slot. Returns the slot index, or -1 if full. */
int server_client_add(struct server *s, int fd)
{
    for (int i = 0; i < MAX_CLIENTS; i++) {
        if (s->clients[i].fd < 0) {
            s->clients[i].fd = fd;
            s->fds[i].fd = fd;
            s->fds[i].events = POLLIN;
            s->fds[i].revents = 0;
            return i;
        }
    }
    return -1;
}

/* Closes client `i`, frees its poll slot and drops it from every channel. */
void server_client_remove(struct server *s, int i)
{
    if (s->clients[i].fd >= 0)
        close(s->clients[i].fd);
    s->clients[i].fd = -1;
    s->fds[i].fd = -1;
    s->fds[i].events = 0;
    s->fds[i].revents = 0;

    for (int channel = 0; channel < MAX_CHANNELS; channel++) {
        struct channel *ch = &s->channels[channel];
        for (int entry = 0; entry < MAX_CLIENTS_PER_CHANNEL; entry++) {
            if (ch->clients[entry] == i) {
                ch->clients[entry] = -1;
                ch->clients_count--;
            }
        }
    }
}

/* Sends `msg` to client `i` (best effort, result ignored) and then disconnects it. */
void server_client_error(struct server *s, int i, const char *msg)
{
    if (s->clients[i].fd >= 0)
        send(s->clients[i].fd, msg, strlen(msg), 0);
    server_client_remove(s, i);
}

/*
 * Claims a channel slot for `name` (NUL-padded to CHANNEL_NAME_LENGTH) among
 * the CHANNEL_PROBE_LIMIT slots following its hash bucket, wrapping around the
 * table. A slot is free when it has no clients. Returns the slot, or -1.
 */
int server_channel_create(struct server *s, const char name[CHANNEL_NAME_LENGTH])
{
    uint64_t bucket = hashmap_hash(name, CHANNEL_NAME_LENGTH, MAX_CHANNELS);

    for (uint64_t probe = 0; probe < CHANNEL_PROBE_LIMIT; probe++) {
        int i = (int)((bucket + probe) % MAX_CHANNELS);
        if (s->channels[i].clients_count == 0) {
            /* strncpy is intended here: the field is fixed-width and zero-padded. */
            strncpy(s->channels[i].name, name, CHANNEL_NAME_LENGTH);
            return i;
        }
    }
    return -1;
}

/* Looks up `name` along the same probe sequence as server_channel_create. Returns the slot, or -1. */
int server_channel_find(struct server *s, const char name[CHANNEL_NAME_LENGTH])
{
    if (name[0] == '\0')
        return -1; /* an empty name would match every never-used slot */

    uint64_t bucket = hashmap_hash(name, CHANNEL_NAME_LENGTH, MAX_CHANNELS);
    for (uint64_t probe = 0; probe < CHANNEL_PROBE_LIMIT; probe++) {
        int i = (int)((bucket + probe) % MAX_CHANNELS);
        if (strncmp(s->channels[i].name, name, CHANNEL_NAME_LENGTH) == 0)
            return i;
    }
    return -1;
}

/* Returns the slot of channel `name`, creating it if necessary; -1 on failure. */
int server_channel_find_or_create(struct server *s, const char name[CHANNEL_NAME_LENGTH])
{
    int found = server_channel_find(s, name);
    return found == -1 ? server_channel_create(s, name) : found;
}

/*
 * Subscribes client `client_index` to channel `channel_index`.
 * Returns 0 on success, -1 if it is already subscribed or the channel is full.
 * Membership is checked over the whole list before a free entry is taken, so a
 * hole left by an earlier departure cannot produce a duplicate subscription.
 */
int server_channel_add_client(struct server *s, int channel_index, int client_index)
{
    struct channel *ch = &s->channels[channel_index];
    int free_entry = -1;

    for (int i = 0; i < MAX_CLIENTS_PER_CHANNEL; i++) {
        if (ch->clients[i] == client_index)
            return -1;
        if (ch->clients[i] == -1 && free_entry < 0)
            free_entry = i;
    }
    if (free_entry < 0)
        return -1;
    ch->clients[free_entry] = client_index;
    ch->clients_count++;
    return 0;
}

/* Unsubscribes a client from a channel. Returns 0 on success, -1 if it was not subscribed. */
int server_channel_remove_client(struct server *s, int channel_index, int client_index)
{
    struct channel *ch = &s->channels[channel_index];

    for (int i = 0; i < MAX_CLIENTS_PER_CHANNEL; i++) {
        if (ch->clients[i] == client_index) {
            ch->clients[i] = -1;
            ch->clients_count--;
            return 0;
        }
    }
    return -1;
}

/*
 * Sends `buf` to every subscriber of channel `ci`. Best effort: send() failures
 * and short writes are ignored. Always returns 0.
 */
int server_channel_send(struct server *s, int ci, const void *buf, size_t buf_size)
{
    const int *clients = s->channels[ci].clients;
    int remaining = s->channels[ci].clients_count;

    for (int i = 0; i < MAX_CLIENTS_PER_CHANNEL && remaining > 0; i++) {
        if (clients[i] >= 0) {
            send(s->clients[clients[i]].fd, buf, buf_size, 0);
            remaining--;
        }
    }
    return 0;
}

/* Closes every client socket and the listening socket, then frees `s`. NULL is a no-op. */
void server_free(struct server *s)
{
    if (s == NULL)
        return;
    for (int i = 0; i < MAX_CLIENTS; i++) {
        if (s->clients[i].fd >= 0)
            close(s->clients[i].fd);
    }
    if (s->sock_fd >= 0) /* fd 0 is a valid descriptor, so test against -1 */
        close(s->sock_fd);
    free(s);
}

/*
 * Creates a server listening on the UNIX socket at `socket_path`. Any existing
 * file at that path is unlinked first. On failure, prints a diagnostic to
 * stderr, releases everything acquired and returns NULL.
 */
struct server *server_create(const char *socket_path)
{
    struct sockaddr_un name;

    if (socket_path == NULL || socket_path[0] == '\0') {
        fprintf(stderr, "server_create: empty socket path\n");
        return NULL;
    }
    size_t path_len = strlen(socket_path);
    if (path_len >= sizeof(name.sun_path)) {
        fprintf(stderr, "server_create: socket path too long: %s\n", socket_path);
        return NULL;
    }

    struct server *s = malloc(sizeof *s);
    if (s == NULL) {
        perror("malloc");
        return NULL;
    }

    /* Initialise every slot before any fallible call so server_free() is safe on all error paths. */
    s->sock_fd = -1;
    for (int i = 0; i < MAX_CLIENTS; i++) {
        s->clients[i].fd = -1;
        s->fds[i].fd = -1;
        s->fds[i].events = 0;
        s->fds[i].revents = 0;
    }
    for (int i = 0; i < MAX_CHANNELS; i++) {
        memset(s->channels[i].name, 0, sizeof(s->channels[i].name));
        memset(s->channels[i].clients, -1, sizeof(s->channels[i].clients));
        s->channels[i].clients_count = 0;
    }
    s->fds[LISTEN_SLOT].fd = -1;
    s->fds[LISTEN_SLOT].events = 0;
    s->fds[LISTEN_SLOT].revents = 0;
    s->fd_num = MAX_CLIENTS + 1;

    s->sock_fd = socket(AF_UNIX, SOCK_STREAM, 0);
    if (s->sock_fd == -1) {
        perror("socket");
        server_free(s);
        return NULL;
    }

    memset(&name, 0, sizeof(name));
    name.sun_family = AF_UNIX;
    memcpy(name.sun_path, socket_path, path_len); /* fits: length checked above, buffer zeroed */

    unlink(socket_path);
    if (bind(s->sock_fd, (const struct sockaddr *)&name, sizeof(name)) == -1) {
        perror("bind");
        server_free(s);
        return NULL;
    }
    if (listen(s->sock_fd, BACKLOG) == -1) {
        perror("listen");
        server_free(s);
        return NULL;
    }

    s->fds[LISTEN_SLOT].fd = s->sock_fd;
    s->fds[LISTEN_SLOT].events = POLLIN;
    return s;
}

/*
 * Returns the next TOKEN_DELIMS-separated token at *cursor, NUL-terminating it
 * in place. Advances *cursor past the token and its single terminating
 * delimiter. Returns NULL when no token remains.
 */
static char *next_token(char **cursor)
{
    char *start = *cursor + strspn(*cursor, TOKEN_DELIMS);

    if (*start == '\0') {
        *cursor = start;
        return NULL;
    }
    char *end = start + strcspn(start, TOKEN_DELIMS);
    if (*end != '\0') {
        *end = '\0';
        end++;
    }
    *cursor = end;
    return start;
}

/*
 * Reads a channel name token into `out`, zero-padded to CHANNEL_NAME_LENGTH so
 * it hashes consistently. Returns -1 if the token is missing or too long for a
 * NUL-terminated name.
 */
static int parse_channel_name(char **cursor, char out[CHANNEL_NAME_LENGTH])
{
    const char *token = next_token(cursor);

    memset(out, 0, CHANNEL_NAME_LENGTH);
    if (token == NULL)
        return -1;
    size_t len = strlen(token);
    if (len == 0 || len >= CHANNEL_NAME_LENGTH)
        return -1;
    memcpy(out, token, len);
    return 0;
}

/*
 * Executes one command from client `i`. `line` is modified in place.
 * Returns NULL on success, or the error message to send before disconnecting.
 */
static const char *server_handle_command(struct server *s, int i, char *line)
{
    char *cursor = line;
    const char *command = next_token(&cursor);

    if (command == NULL)
        return "Unknown command";

    int is_listen = strcmp(command, "listen") == 0;
    if (!is_listen && strcmp(command, "send") != 0)
        return "Unknown command";

    char channel_name[CHANNEL_NAME_LENGTH];
    if (parse_channel_name(&cursor, channel_name) < 0)
        return "Invalid channel name";

    if (is_listen) {
        int channel_id = server_channel_find_or_create(s, channel_name);
        if (channel_id < 0)
            return "Could not find or create channel";
        if (server_channel_add_client(s, channel_id, i) < 0)
            return "Could not join channel";
        return NULL;
    }

    /* The payload is the rest of the message (spaces included), minus leading blanks. */
    const char *payload = cursor + strspn(cursor, " ");
    if (*payload == '\0')
        return "Missing payload";

    int channel_id = server_channel_find(s, channel_name);
    if (channel_id < 0)
        return "Could not find channel";
    if (server_channel_send(s, channel_id, payload, strlen(payload)) < 0)
        return "Could not send message";
    return NULL;
}

/*
 * Reads one chunk from client `i` and runs it as a command. Disconnects the
 * client on EOF or on a receive error. A transient EINTR/EAGAIN is ignored.
 */
static void server_client_read(struct server *s, int i)
{
    char data[RECV_BUFFER_SIZE];
    /* Leave room for the NUL terminator the command parser relies on. */
    ssize_t bytes = recv(s->clients[i].fd, data, sizeof(data) - 1, 0);

    if (bytes < 0) {
        if (errno == EINTR || errno == EAGAIN)
            return;
        server_client_remove(s, i); /* e.g. ECONNRESET: drop only this client */
        return;
    }
    if (bytes == 0) { /* orderly disconnect */
        server_client_remove(s, i);
        return;
    }
    data[bytes] = '\0';

    const char *error = server_handle_command(s, i, data);
    if (error != NULL)
        server_client_error(s, i, error);
}

/*
 * Accepts one pending connection. If every slot is taken, the new connection
 * is closed. Returns -1 only on a non-transient accept() failure.
 */
static int server_accept(struct server *s)
{
    int fd = accept(s->sock_fd, NULL, NULL);

    if (fd == -1) {
        if (errno == EINTR || errno == EAGAIN || errno == ECONNABORTED)
            return 0;
        perror("accept");
        return -1;
    }
    if (server_client_add(s, fd) == -1)
        close(fd); /* server full */
    return 0;
}

/*
 * Runs one poll() iteration: services readable or failed client sockets in slot
 * order, then accepts a pending connection. Returns 0 to keep going, or -1 on
 * an unrecoverable error (a diagnostic has already been printed).
 */
int server_turn(struct server *s)
{
    if (poll(s->fds, (nfds_t)s->fd_num, -1) < 0) {
        if (errno == EINTR)
            return 0; /* interrupted by a signal: just poll again */
        perror("poll");
        return -1;
    }

    for (int i = 0; i < MAX_CLIENTS; i++) {
        short revents = s->fds[i].revents;

        if (s->clients[i].fd < 0 || revents == 0)
            continue;
        if (revents & POLLIN) {
            server_client_read(s, i);
        } else if (revents & (POLLERR | POLLHUP | POLLNVAL)) {
            /* Without this the error condition stays pending and poll() would spin. */
            server_client_remove(s, i);
        }
    }

    if (s->fds[LISTEN_SLOT].revents & POLLIN)
        return server_accept(s);
    return 0;
}

/* Pops the first argument from (*argc, *argv); returns NULL when none remain. */
const char *arg_shift(int *argc, char ***argv)
{
    if (*argc < 1)
        return NULL;
    const char *result = (*argv)[0];
    *argc -= 1;
    *argv += 1;
    return result;
}

int main(int argc, char *argv[])
{
    arg_shift(&argc, &argv); /* skip the program name */
    const char *socket_path = arg_shift(&argc, &argv);
    socket_path = socket_path == NULL ? "/tmp/duct-unix0" : socket_path;

    /* Sending to a listener that has hung up would otherwise raise SIGPIPE and
       terminate the server; ignored, send() just fails with EPIPE for that client. */
    signal(SIGPIPE, SIG_IGN);

    struct server *srv = server_create(socket_path);
    if (srv == NULL) {
        fprintf(stderr, "server_create failed\n");
        return EXIT_FAILURE;
    }

    printf("Listening on %s\n", socket_path);
    while (1) {
        if (server_turn(srv) < 0) {
            fprintf(stderr, "server_turn failed\n");
            server_free(srv);
            return EXIT_FAILURE;
        }
    }

    server_free(srv);
    return EXIT_SUCCESS;
}